From b24a544ec9f947d58594d080455d692c98c786e1 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 22 Jan 2024 17:24:41 +0800 Subject: [PATCH 01/27] [dnn] Add dnn.linear taichi implmentation --- brainpy/_src/dnn/linear.py | 708 ++++++++++++++++++++++++++ brainpy/_src/dnn/tests/test_linear.py | 127 +++++ brainpy/dnn/linear.py | 8 + 3 files changed, 843 insertions(+) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 09bf2958d..b837dd920 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -8,6 +8,7 @@ import jax.numpy as jnp import numba import numpy as np +import taichi as ti from brainpy import math as bm from brainpy._src import connect, initialize as init @@ -29,6 +30,9 @@ 'CSRLinear', 'EventCSRLinear', 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear', 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear', + 'CSRLinear_taichi', 'EventCSRLinear_taichi', + 'JitFPHomoLinear_taichi', 'JitFPUniformLinear_taichi', 'JitFPNormalLinear_taichi', + 'EventJitFPHomoLinear_taichi', 'EventJitFPNormalLinear_taichi', 'EventJitFPUniformLinear_taichi', ] @@ -660,9 +664,11 @@ def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) + csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update) + def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): if w_min is None: w_min = -np.inf @@ -1278,3 +1284,705 @@ def _batch_mv(self, x): shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) + +### TAICHI CUSTOMIZED OPERATOR IMPLEMENTATION ### + +class _CSRLayer_taichi(Layer, SupportSTDP): + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + 
super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + assert sharding is None, 'Currently this model does not support sharding.' + self.conn = conn + self.sharding = sharding + self.transpose = transpose + + # connection + self.indices, self.indptr = self.conn.require('csr') + + # weight + weight = init.parameter(weight, (self.indices.size,)) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if bm.isscalar(self.weight): + raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.') + if self.weight.shape != self.indices.shape: + raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: # update on presynaptic spike + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = csr_on_pre_update_taichi(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max) + if on_post is not None: # update on postsynaptic spike + if not hasattr(self, '_pre_ids'): + with jax.ensure_compile_time_eval(): + self._pre_ids, self._post_indptr, self.w_indices = csr2csc( + [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size) + ) + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = csc_on_post_update_taichi(self.weight.value, self._pre_ids, self._post_indptr, + self.w_indices, spike, trace, w_min, w_max) + + +class CSRLinear_taichi(_CSRLayer_taichi): + r"""Synaptic matrix multiplication with CSR sparse computation(taichi customized operator). + + It performs the computation of: + + .. 
math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a CSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) + + def update(self, x): + if x.ndim == 1: + return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose)[0] + elif x.ndim > 1: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_csrmv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_csrmv(self, x): + return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose)[0] + +class EventCSRLinear_taichi(_CSRLayer_taichi): + r"""Synaptic matrix multiplication with event CSR sparse computation(taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weight using a CSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) + + def update(self, x): + if x.ndim == 1: + return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose)[0] + elif x.ndim > 1: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_csrmv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_csrmv(self, x): + return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose)[0] + +@ti.kernel +def _cpu_csr_on_pre_update_taichi(w: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=1)): + w_value = w[0] + out_w[:] = w_value + w_min_value = w_min[0] + w_max_value = w_max[0] + for i in range(spike.shape[0]): # pre id + if spike[i]: + for k in range(indptr[i], indptr[i + 1]): # synapse id + j = indices[k] # post id + out_w[k] = ti.min(ti.max(out_w[k] + trace[j], w_min_value), w_max_value) + +csr_on_pre_update_prim_taichi = bm.XLACustomOp(_cpu_csr_on_pre_update_taichi) + + +def csr_on_pre_update_taichi(w, indices, indptr, spike, trace, w_min=None, w_max=None): + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w = jax.Array(w) + w_min = jax.Array(w_min) + w_max = jax.Array(w_max) + return csr_on_pre_update_prim_taichi(w, indices, indptr, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, 
w.dtype)])[0] + +@ti.kernel +def _cpu_csc_on_pre_update_taichi(w: ti.types.ndarray(ndim=1), + post_ids: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + w_ids: ti.types.ndarray(ndim=1), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=1)): + w_value = w[0] + out_w[:] = w_value + w_min_value = w_min[0] + w_max_value = w_max[0] + + for i in range(spike.shape[0]): # post id + if spike[i]: + for k in range(indptr[i], indptr[i + 1]): + j = post_ids[k] # pre id + l = w_ids[k] # syn id + out_w[l] = ti.min(ti.max(out_w[l] + trace[j], w_min_value), w_max_value) + +csc_on_pre_update_prim_taichi = bm.XLACustomOp(_cpu_csc_on_pre_update_taichi) + + +def csc_on_post_update_taichi(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None): + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w = jax.Array(w) + w_min = jax.Array(w_min) + w_max = jax.Array(w_max) + return csc_on_pre_update_prim_taichi(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + +class JitFPHomoLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is the same :math:`weight`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + weight: float. The synaptic value at each position. + seed: int. 
The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + weight: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class JitFPUniformLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + + It performs the computation of: + + .. 
math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights, which have fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sampled from a uniform distribution :math:`U(w_{low}, w_{high})`. + + Args: + num_in: int. The number of input features. A positive integer. + num_out: int. The number of output features. A positive integer. + prob: float. The connectivity probability. + w_low: float. The lowest value of the uniform distribution. + w_high: float. The highest value of the uniform distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_low: float, + w_high: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_low = w_low + self.w_high = w_high + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class JitFPNormalLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. 
The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_mu: float. The center of the normal distribution. + w_sigma: float. The standard variance of the normal distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_mu: float, + w_sigma: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + transpose: bool = False, + atomic: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_mu = w_mu + self.w_sigma = w_sigma + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPHomoLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the 
just-in-time connectivity (taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights, which have fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is the same :math:`weight`. + + Args: + num_in: int. The number of input features. A positive integer. + num_out: int. The number of output features. A positive integer. + prob: float. The connectivity probability. + weight: float. The synaptic value at each position. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + weight: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 1000000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPUniformLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. + + Args: + num_in: int. The number of the input feature. 
A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_low: float. The lowest value of the uniform distribution. + w_high: float. The highest value of the uniform distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_low: float, + w_high: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_low = w_low + self.w_high = w_high + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class 
EventJitFPNormalLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity (taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights, which have fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sampled from a normal distribution :math:`N(\mu, \sigma)`. + + Args: + num_in: int. The number of input features. A positive integer. + num_out: int. The number of output features. A positive integer. + prob: float. The connectivity probability. + w_mu: float. The mean of the normal distribution. + w_sigma: float. The standard deviation of the normal distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_mu: float, + w_sigma: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + transpose: bool = False, + atomic: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_mu = w_mu + self.w_sigma = w_sigma + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index da49bdbfe..48fc60a01 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -213,6 +213,133 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): self.assertTrue(y2.shape == shape + (200,)) bm.clear_buffer_memory() + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_CSRLinear_taichi(self, conn): + bm.random.seed() + f = bp.dnn.CSRLinear_taichi(conn, weight=bp.init.Normal()) + x = 
bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + bm.clear_buffer_memory() + + + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_EventCSRLinear_taichi(self,conn): + bm.random.seed() + f=bp.layers.EventCSRLinear_taichi(conn,weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + bm.clear_buffer_memory() + + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPHomoLinear_taichi(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.JitFPHomoLinear_taichi(100, 200, prob, weight, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPUniformLinear_taichi(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.JitFPUniformLinear_taichi(100, 200, prob, w_low, w_high, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPNormalLinear_taichi(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.JitFPNormalLinear_taichi(100, 200, prob, w_mu, w_sigma, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + 
prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPHomoLinear_taichi(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.EventJitFPHomoLinear_taichi(100, 200, prob, weight, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPUniformLinear_taichi(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.EventJitFPUniformLinear_taichi(100, 200, prob, w_low, w_high, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPNormalLinear_taichi(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.EventJitFPNormalLinear_taichi(100, 200, prob, w_mu, w_sigma, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + if __name__ == '__main__': absltest.main() diff --git a/brainpy/dnn/linear.py b/brainpy/dnn/linear.py index 762c3c282..44a51b9da 100644 --- a/brainpy/dnn/linear.py +++ b/brainpy/dnn/linear.py @@ -14,4 +14,12 @@ EventJitFPHomoLinear as EventJitFPHomoLinear, EventJitFPNormalLinear as EventJitFPNormalLinear, EventJitFPUniformLinear as EventJitFPUniformLinear, + 
CSRLinear_taichi as CSRLinear_taichi, + EventCSRLinear_taichi as EventCSRLinear_taichi, + JitFPHomoLinear_taichi as JitFPHomoLinear_taichi, + JitFPUniformLinear_taichi as JitFPUniformLinear_taichi, + JitFPNormalLinear_taichi as JitFPNormalLinear_taichi, + EventJitFPHomoLinear_taichi as EventJitFPHomoLinear_taichi, + EventJitFPNormalLinear_taichi as EventJitFPNormalLinear_taichi, + EventJitFPUniformLinear_taichi as EventJitFPUniformLinear_taichi, ) From 65305b20368df61fbd74d4456949620753014d2c Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 22 Jan 2024 19:17:31 +0800 Subject: [PATCH 02/27] [math] Remove multiple results of event csrmv and csrmv --- brainpy/_src/math/event/_csr_matvec_taichi.py | 76 +++--- .../event/tests/test_event_csrmv_taichi.py | 34 +-- brainpy/_src/math/sparse/_csr_mv_taichi.py | 60 +++-- .../math/sparse/tests/test_csrmv_taichi.py | 243 +----------------- 4 files changed, 102 insertions(+), 311 deletions(-) diff --git a/brainpy/_src/math/event/_csr_matvec_taichi.py b/brainpy/_src/math/event/_csr_matvec_taichi.py index 9be9c49d9..2ee47d838 100644 --- a/brainpy/_src/math/event/_csr_matvec_taichi.py +++ b/brainpy/_src/math/event/_csr_matvec_taichi.py @@ -10,7 +10,7 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse._csr_mv_taichi import csrmv_taichi as normal_csrmv_taichi +from brainpy._src.math.sparse._csr_mv_taichi import raw_csrmv_taichi as normal_csrmv_taichi from brainpy._src.math.sparse._utils import csr_to_coo ti = import_taichi() @@ -333,13 +333,53 @@ def _event_csr_matvec_transpose( ct_values = ad.Zero(values) else: if values.aval.shape[0] == 1: # scalar - ct_values = csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] + ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] ct_values = 
jnp.inner(ct[0], ct_values) else: # heterogeneous values row, col = csr_to_coo(indices, indptr) ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] return ct_values, indices, indptr, events +def raw_csrmv_taichi( + data: Union[float, jax.Array], + indices: jax.Array, + indptr: jax.Array, + events: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False +): + if transpose: + if events.dtype == jnp.bool_: + if data.shape[0] == 1: + prim = _event_csrmv_transpose_bool_homo_p + else: + prim = _event_csrmv_transpose_bool_heter_p + else: + if data.shape[0] == 1: + prim = _event_csrmv_transpose_homo_p + else: + prim = _event_csrmv_transpose_heter_p + else: + if events.dtype == jnp.bool_: + if data.shape[0] == 1: + prim = _event_csrmv_bool_homo_p + else: + prim = _event_csrmv_bool_heter_p + else: + if data.shape[0] == 1: + prim = _event_csrmv_homo_p + else: + prim = _event_csrmv_heter_p + + # computing + return prim(data, + indices, + indptr, + events, + outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)], + transpose=transpose, + shape=shape) def csrmv_taichi( data: Union[float, jax.Array], @@ -419,37 +459,7 @@ def csrmv_taichi( if indices.shape[0] == 0: return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) - if transpose: - if events.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csrmv_transpose_bool_homo_p - else: - prim = _event_csrmv_transpose_bool_heter_p - else: - if data.shape[0] == 1: - prim = _event_csrmv_transpose_homo_p - else: - prim = _event_csrmv_transpose_heter_p - else: - if events.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csrmv_bool_homo_p - else: - prim = _event_csrmv_bool_heter_p - else: - if data.shape[0] == 1: - prim = _event_csrmv_homo_p - else: - prim = _event_csrmv_heter_p - - # computing - return prim(data, - indices, - indptr, - events, - outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), 
dtype=data.dtype)], - transpose=transpose, - shape=shape) + return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] def _define_op(cpu_kernel, gpu_kernel): diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py b/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py index b759a4789..c81aee7c0 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py @@ -20,14 +20,6 @@ def func(*args, **kwargs): return func -def sum_op2(op): - def func(*args, **kwargs): - r = op(*args, **kwargs)[0] - return r.sum() - - return func - - class Test_event_csr_matvec_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs) @@ -53,7 +45,7 @@ def test_homo(self, transpose, shape, homo_data): r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) r2 = bm.event.csrmv_taichi(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - assert (bm.allclose(r1, r2[0])) + assert (bm.allclose(r1, r2)) bm.clear_buffer_memory() @@ -78,7 +70,7 @@ def test_homo_vmap(self, shape, transpose, homo_data): f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax([homo_data] * 10) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)[0])) + self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) # vmap 'events' f3 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr, @@ -86,7 +78,7 @@ def test_homo_vmap(self, shape, transpose, homo_data): f4 = jax.vmap(partial(bm.event.csrmv_taichi, homo_data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)[0])) + self.assertTrue(bm.allclose(f3(vmap_data), 
f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) @@ -95,7 +87,7 @@ def test_homo_vmap(self, shape, transpose, homo_data): vmap_data1 = bm.as_jax([homo_data] * 10) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2)[0])) + f6(vmap_data1, vmap_data2))) bm.clear_buffer_memory() @@ -120,14 +112,14 @@ def test_homo_grad(self, shape, transpose, homo_data): # grad 'data' r1 = jax.grad(sum_op(bm.event.csrmv))( homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))( + r2 = jax.grad(sum_op(bm.event.csrmv_taichi))( homo_data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( + r4 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) @@ -154,7 +146,7 @@ def test_heter(self, shape, transpose): r2 = bm.event.csrmv_taichi(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - assert (bm.allclose(r1, r2[0])) + assert (bm.allclose(r1, r2)) bm.clear_buffer_memory() @@ -180,7 +172,7 @@ def test_heter_vmap(self, shape, transpose): f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)[0])) + self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) # vmap 'events' data = bm.as_jax(rng.random(indices.shape)) @@ -189,7 +181,7 @@ def 
test_heter_vmap(self, shape, transpose): f4 = jax.vmap(partial(bm.event.csrmv_taichi, data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)[0])) + self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, @@ -199,7 +191,7 @@ def test_heter_vmap(self, shape, transpose): vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2)[0])) + f6(vmap_data1, vmap_data2))) bm.clear_buffer_memory() @@ -225,20 +217,20 @@ def test_heter_grad(self, shape, transpose): data = bm.as_jax(rng.random(indices.shape)) r1 = jax.grad(sum_op(bm.event.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))( + r2 = jax.grad(sum_op(bm.event.csrmv_taichi))( data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( + r4 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) r5 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=(0, 3))( + r6 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) 
self.assertTrue(bm.allclose(r5[1], r6[1])) diff --git a/brainpy/_src/math/sparse/_csr_mv_taichi.py b/brainpy/_src/math/sparse/_csr_mv_taichi.py index cd09af08e..5038e372e 100644 --- a/brainpy/_src/math/sparse/_csr_mv_taichi.py +++ b/brainpy/_src/math/sparse/_csr_mv_taichi.py @@ -155,11 +155,11 @@ def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) + return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) + return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) def _sparse_csr_matvec_transpose( @@ -168,7 +168,7 @@ def _sparse_csr_matvec_transpose( if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): raise ValueError("Cannot transpose with respect to sparse indices.") if ad.is_undefined_primal(vector): - ct_vector = csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector) else: @@ -176,7 +176,7 @@ def _sparse_csr_matvec_transpose( ct_data = ad.Zero(data) else: if data.aval.shape[0] == 1: # scalar - ct_data = csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] + ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] ct_data = jnp.inner(ct[0], ct_data) else: row, col = csr_to_coo(indices, indptr) @@ -184,6 +184,35 @@ def _sparse_csr_matvec_transpose( return 
ct_data, indices, indptr, vector +def raw_csrmv_taichi( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + vector: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, +): + out_shape = shape[1] if transpose else shape[0] + if transpose: + if data.shape[0] == 1: + prim = _csr_matvec_transpose_homo_p + else: + prim = _csr_matvec_transpose_heter_p + else: + if data.shape[0] == 1: + prim = _csr_matvec_homo_p + else: + prim = _csr_matvec_heter_p + + return prim(data, + indices, + indptr, + vector, + outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)], + transpose=transpose, + shape=shape) + def csrmv_taichi( data: Union[float, jnp.ndarray, Array], @@ -242,26 +271,9 @@ def csrmv_taichi( raise ValueError('indices should be a 1D vector with integer type.') if not jnp.issubdtype(indptr.dtype, jnp.integer): raise ValueError('indptr should be a 1D vector with integer type.') - out_shape = shape[1] if transpose else shape[0] - - if transpose: - if data.shape[0] == 1: - prim = _csr_matvec_transpose_homo_p - else: - prim = _csr_matvec_transpose_heter_p - else: - if data.shape[0] == 1: - prim = _csr_matvec_homo_p - else: - prim = _csr_matvec_heter_p - - return prim(data, - indices, - indptr, - vector, - outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)], - transpose=transpose, - shape=shape) + + return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] + def _define_op(cpu_kernel, gpu_kernel): diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py index 2b3d7b5b0..fed665c8d 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py @@ -21,13 +21,6 @@ def func(*args, **kwargs): return func -def sum_op2(op): - def func(*args, **kwargs): - r = op(*args, **kwargs)[0] - return r.sum() - - return func - def 
compare_with_nan_tolerance(a, b, tol=1e-8): """ @@ -62,219 +55,6 @@ def compare_with_nan_tolerance(a, b, tol=1e-8): vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') -### MANUAL TESTS ### -# transposes = [True, False] -# homo_datas = [-1., 0., 0.1, 1.] -# shapes = [(100, 200), (10, 1000), (2, 2000)] -# -# -# def test_homo(transpose, shape, homo_data): -# print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') -# conn = bp.conn.FixedProb(0.1) -# -# # matrix -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# # vector -# rng = bm.random.RandomState(123) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# r1 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) -# r2 = bm.sparse.csrmv_taichi(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) -# assert (bm.allclose(r1, r2[0])) -# -# bm.clear_buffer_memory() -# -# -# def test_homo_vmap(transpose, shape, homo_data): -# print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') -# rng = bm.random.RandomState() -# conn = bp.conn.FixedProb(0.1) -# -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# heter_data = bm.ones((10, indices.shape[0])).value * homo_data -# homo_data = bm.ones(10).value * homo_data -# dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) -# -# f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector, -# shape=shape, transpose=transpose) -# f2 = partial(bm.sparse.csrmv_taichi, indices=indices, indptr=indptr, vector=vector, -# shape=shape, transpose=transpose) -# r1 = jax.vmap(f1)(homo_data) -# r2 = jax.vmap(f1)(homo_data) -# assert 
(bm.allclose(r1, r2[0])) -# -# bm.clear_buffer_memory() -# -# -# def test_homo_grad(transpose, shape, homo_data): -# print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') -# rng = bm.random.RandomState() -# conn = bp.conn.FixedProb(0.1) -# -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, -# indices, -# indptr, -# shape=shape) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# # print('grad data start') -# # grad 'data' -# r1 = jax.grad(sum_op(vector_csr_matvec))( -# homo_data, indices, indptr, vector, shape=shape, transpose=transpose) -# r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))( -# homo_data, indices, indptr, vector, shape=shape, transpose=transpose) -# -# # csr_f1 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, -# # shape=shape, transpose=transpose).sum(), -# # argnums=0) -# # csr_f2 = jax.grad(lambda a: bm.sparse.csrmv_taichi(a, indices, indptr, vector, -# # shape=shape, transpose=transpose)[0].sum(), -# # argnums=0) -# # r1 = csr_f1(homo_data) -# # r2 = csr_f2(homo_data) -# assert (bm.allclose(r1, r2)) -# -# # print('grad vector start') -# # grad 'vector' -# r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( -# homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( -# homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# # csr_f3 = jax.grad(lambda v: vector_csr_matvec(homo_data, indices, indptr, v, -# # shape=shape, transpose=transpose).sum()) -# # csr_f4 = jax.grad(lambda v: bm.sparse.csrmv_taichi(homo_data, indices, indptr, v, -# # shape=shape, transpose=transpose)[0].sum()) -# # r3 = csr_f3(vector) -# # r4 = csr_f4(vector) -# assert (bm.allclose(r3, r4)) -# -# # csr_f5 = jax.grad(lambda a, v: 
vector_csr_matvec(a, indices, indptr, v, -# # shape=shape, transpose=transpose).sum(), -# # argnums=(0, 1)) -# # csr_f6 = jax.grad(lambda a, v: bm.sparse.csrmv_taichi(a, indices, indptr, v, -# # shape=shape, transpose=transpose)[0].sum(), -# # argnums=(0, 1)) -# # r5 = csr_f5(homo_data, vector) -# # r6 = csr_f6(homo_data, vector) -# # assert(bm.allclose(r5[0], r6[0])) -# # assert(bm.allclose(r5[1], r6[1])) -# -# bm.clear_buffer_memory() -# -# -# def test_heter(transpose, shape): -# print(f'test_heter: transpose = {transpose} shape = {shape}') -# rng = bm.random.RandomState() -# conn = bp.conn.FixedProb(0.1) -# -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# heter_data = bm.as_jax(rng.random(indices.shape)) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# r1 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) -# r2 = bm.sparse.csrmv_taichi(heter_data, indices, indptr, vector, shape=shape) -# # bm.nan_to_num(r1) -# # bm.nan_to_num(r2[0]) -# # print(r1) -# # print(r1 - r2[0]) -# assert (compare_with_nan_tolerance(r1, r2[0])) -# -# bm.clear_buffer_memory() -# -# -# def test_heter_vmap(transpose, shape): -# print(f'test_heter_vmap: transpose = {transpose} shape = {shape}') -# rng = bm.random.RandomState() -# conn = bp.conn.FixedProb(0.1) -# -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# heter_data = rng.random((10, indices.shape[0])) -# heter_data = bm.as_jax(heter_data) -# dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, -# shape=shape))(heter_data) -# -# f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector, -# shape=shape, transpose=transpose) -# f2 = partial(bm.sparse.csrmv_taichi, indices=indices, indptr=indptr, 
vector=vector, -# shape=shape, transpose=transpose) -# r1 = jax.vmap(f1)(heter_data) -# r2 = jax.vmap(f2)(heter_data) -# assert (bm.allclose(r1, r2[0])) -# -# -# def test_heter_grad(transpose, shape): -# print(f'test_heter_grad: transpose = {transpose} shape = {shape}') -# rng = bm.random.RandomState() -# conn = bp.conn.FixedProb(0.1) -# -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# heter_data = rng.random(indices.shape) -# heter_data = bm.as_jax(heter_data) -# dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# # grad 'data' -# r1 = jax.grad(sum_op(vector_csr_matvec))( -# heter_data, indices, indptr, vector, shape=shape, transpose=transpose) -# r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))( -# heter_data, indices, indptr, vector, shape=shape, transpose=transpose) -# assert (bm.allclose(r1, r2)) -# -# # grad 'vector' -# r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( -# heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( -# heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# assert (bm.allclose(r3, r4)) -# -# r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))( -# heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# r6 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=(0, 3))( -# heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# assert (bm.allclose(r5[0], r6[0])) -# assert (bm.allclose(r5[1], r6[1])) -# -# bm.clear_buffer_memory() -# -# def test_all(): -# # for transpose in transposes: -# # for shape in shapes: -# # for homo_data in homo_datas: -# # test_homo(transpose, shape, homo_data) -# # test_homo_vmap(transpose, shape, homo_data) -# # 
test_homo_grad(transpose, shape, homo_data) -# -# for transpose in transposes: -# for shape in shapes: -# test_heter(transpose, shape) -# test_heter_vmap(transpose, shape) -# test_heter_grad(transpose, shape) -# test_all() - -# PYTEST class Test_csrmv_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): super(Test_csrmv_taichi, self).__init__(*args, **kwargs) @@ -302,7 +82,7 @@ def test_homo(self, transpose, shape, homo_data): r1 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) r2 = bm.sparse.csrmv_taichi(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2[0])) + self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() @@ -332,7 +112,7 @@ def test_homo_vmap(self, transpose, shape, v): shape=shape, transpose=transpose) r1 = jax.vmap(f1)(homo_data) r2 = jax.vmap(f1)(homo_data) - self.assertTrue(bm.allclose(r1, r2[0])) + self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() @@ -360,7 +140,7 @@ def test_homo_grad(self, transpose, shape, homo_data): # grad 'data' r1 = jax.grad(sum_op(vector_csr_matvec))( homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))( + r2 = jax.grad(sum_op(bm.sparse.csrmv_taichi))( homo_data, indices, indptr, vector, shape=shape, transpose=transpose) # csr_f1 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, @@ -377,14 +157,14 @@ def test_homo_grad(self, transpose, shape, homo_data): # grad 'vector' r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( + r4 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=3)( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 
3))( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=(0, 3))( + r6 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=(0, 3))( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) @@ -413,10 +193,7 @@ def test_heter(self, transpose, shape): r1 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) r2 = bm.sparse.csrmv_taichi(heter_data, indices, indptr, vector, shape=shape) - print(r1) - print(r2[0]) - - self.assertTrue(compare_with_nan_tolerance(r1, r2[0])) + self.assertTrue(compare_with_nan_tolerance(r1, r2)) bm.clear_buffer_memory() @@ -445,7 +222,7 @@ def test_heter_vmap(self, transpose, shape): shape=shape, transpose=transpose) r1 = jax.vmap(f1)(heter_data) r2 = jax.vmap(f2)(heter_data) - self.assertTrue(compare_with_nan_tolerance(r1, r2[0])) + self.assertTrue(compare_with_nan_tolerance(r1, r2)) @parameterized.product( transpose=[True, False], @@ -467,20 +244,20 @@ def test_heter_grad(self, transpose, shape): # grad 'data' r1 = jax.grad(sum_op(vector_csr_matvec))( heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))( + r2 = jax.grad(sum_op(bm.sparse.csrmv_taichi))( heter_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'vector' r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( + r4 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=3)( heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))( heter_data, indices, indptr, vector.astype(float), 
shape=shape, transpose=transpose) - r6 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=(0, 3))( + r6 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=(0, 3))( heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) From 8710ad8abb88f4c20a9abf710c4c0a58e4f68297 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 22 Jan 2024 21:27:35 +0800 Subject: [PATCH 03/27] [dnn] Fix bugs --- brainpy/_src/dnn/linear.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index b837dd920..36327702a 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -1378,7 +1378,7 @@ def update(self, x): if x.ndim == 1: return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose)[0] + transpose=self.transpose) elif x.ndim > 1: shapes = x.shape[:-1] x = bm.flatten(x, end_dim=-2) @@ -1390,7 +1390,7 @@ def update(self, x): def _batch_csrmv(self, x): return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose)[0] + transpose=self.transpose) class EventCSRLinear_taichi(_CSRLayer_taichi): r"""Synaptic matrix multiplication with event CSR sparse computation(taichi customized operator). 
@@ -1427,7 +1427,7 @@ def update(self, x): if x.ndim == 1: return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose)[0] + transpose=self.transpose) elif x.ndim > 1: shapes = x.shape[:-1] x = bm.flatten(x, end_dim=-2) @@ -1439,7 +1439,7 @@ def update(self, x): def _batch_csrmv(self, x): return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose)[0] + transpose=self.transpose) @ti.kernel def _cpu_csr_on_pre_update_taichi(w: ti.types.ndarray(ndim=1), From 61c2a079a477b4c1c253544a81c09356d700f69d Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 24 Jan 2024 17:32:52 +0800 Subject: [PATCH 04/27] [dnn] Update jitconn event atomic=True --- brainpy/_src/dnn/linear.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 36327702a..559a3ed95 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -1086,7 +1086,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, transpose: bool = False, - atomic: bool = False, + atomic: bool = True, ): super().__init__(name=name, mode=mode) @@ -1167,7 +1167,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, transpose: bool = False, - atomic: bool = False, + atomic: bool = True, ): super().__init__(name=name, mode=mode) @@ -1245,7 +1245,7 @@ def __init__( seed: Optional[int] = None, sharding: Optional[Sharding] = None, transpose: bool = False, - atomic: bool = False, + atomic: bool = True, mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): @@ -1788,7 +1788,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, transpose: bool = False, - atomic: bool = False, + atomic: bool = True, ): super().__init__(name=name, mode=mode) @@ -1869,7 +1869,7 @@ def 
__init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, transpose: bool = False, - atomic: bool = False, + atomic: bool = True, ): super().__init__(name=name, mode=mode) @@ -1947,7 +1947,7 @@ def __init__( seed: Optional[int] = None, sharding: Optional[Sharding] = None, transpose: bool = False, - atomic: bool = False, + atomic: bool = True, mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): From f1ac501ef6419e8935c003c87f5979d169b2cb84 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 25 Jan 2024 17:31:51 +0800 Subject: [PATCH 05/27] [dnn] Replace brainpylib opeartors with taichi customized operators --- brainpy/_src/dnn/linear.py | 810 +++---------------------------------- brainpy/dnn/linear.py | 8 - 2 files changed, 59 insertions(+), 759 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 559a3ed95..d3de95cbf 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -30,9 +30,6 @@ 'CSRLinear', 'EventCSRLinear', 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear', 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear', - 'CSRLinear_taichi', 'EventCSRLinear_taichi', - 'JitFPHomoLinear_taichi', 'JitFPUniformLinear_taichi', 'JitFPNormalLinear_taichi', - 'EventJitFPHomoLinear_taichi', 'EventJitFPNormalLinear_taichi', 'EventJitFPUniformLinear_taichi', ] @@ -574,18 +571,15 @@ def __init__( sharding: Optional[Sharding] = None, mode: Optional[bm.Mode] = None, name: Optional[str] = None, - method: str = 'cusparse', transpose: bool = True, ): super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) - self.method = method def update(self, x): if x.ndim == 1: - return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, + return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose, - method=self.method) + 
transpose=self.transpose) elif x.ndim > 1: shapes = x.shape[:-1] x = bm.flatten(x, end_dim=-2) @@ -595,11 +589,9 @@ def update(self, x): raise ValueError def _batch_csrmv(self, x): - return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, + return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose, - method=self.method) - + transpose=self.transpose) class EventCSRLinear(_CSRLayer): r"""Synaptic matrix multiplication with event CSR sparse computation. @@ -634,7 +626,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.event.csrmv(self.weight, self.indices, self.indptr, x, + return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) elif x.ndim > 1: @@ -646,50 +638,64 @@ def update(self, x): raise ValueError def _batch_csrmv(self, x): - return bm.event.csrmv(self.weight, self.indices, self.indptr, x, + return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) - -@numba.njit(nogil=True, fastmath=True, parallel=False) -def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): - out_w[:] = w - w_min = w_min[()] - w_max = w_max[()] - for i in numba.prange(spike.shape[0]): # pre id +@ti.kernel +def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=1)): + w_value = w[0] + out_w[:] = w_value + w_min_value = w_min[0] + w_max_value = w_max[0] + for i in range(spike.shape[0]): # pre id if spike[i]: for k in range(indptr[i], indptr[i + 1]): # synapse id j = indices[k] # post id - # out_w[k] = np.clip(out_w[k] + trace[j], w_min, 
w_max) - out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) - - + out_w[k] = ti.min(ti.max(out_w[k] + trace[j], w_min_value), w_max_value) csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update) - def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf + w = jax.Array(w) + w_min = jax.Array(w_min) + w_max = jax.Array(w_max) return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] +@ti.kernel +def _cpu_csc_on_pre_update(w: ti.types.ndarray(ndim=1), + post_ids: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + w_ids: ti.types.ndarray(ndim=1), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=1)): + w_value = w[0] + out_w[:] = w_value + w_min_value = w_min[0] + w_max_value = w_max[0] -@numba.njit(nogil=True, fastmath=True, parallel=False) -def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): - out_w[:] = w - w_min = w_min[()] - w_max = w_max[()] - for i in numba.prange(spike.shape[0]): # post id + for i in range(spike.shape[0]): # post id if spike[i]: for k in range(indptr[i], indptr[i + 1]): j = post_ids[k] # pre id l = w_ids[k] # syn id - out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) - + out_w[l] = ti.min(ti.max(out_w[l] + trace[j], w_min_value), w_max_value) csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) @@ -699,10 +705,14 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_m w_min = -np.inf if w_max is None: w_max = np.inf + w = jax.Array(w) + w_min = jax.Array(w_min) + w_max = jax.Array(w_max) return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + class 
CSCLinear(Layer): r"""Synaptic matrix multiplication with CSC sparse computation. @@ -809,7 +819,7 @@ def __init__( class JitFPHomoLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). It performs the computation of: @@ -866,7 +876,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, + return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -881,14 +891,14 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, + return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) class JitFPUniformLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). 
It performs the computation of: @@ -946,7 +956,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -961,14 +971,14 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) class JitFPNormalLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). It performs the computation of: @@ -1026,7 +1036,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1041,14 +1051,14 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) class EventJitFPHomoLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). 
It performs the computation of: @@ -1105,7 +1115,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, + return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1120,14 +1130,14 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, + return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) class EventJitFPUniformLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). It performs the computation of: @@ -1185,7 +1195,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1200,715 +1210,13 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) class EventJitFPNormalLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. 
math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_mu: float. The center of the normal distribution. - w_sigma: float. The standard variance of the normal distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. 
- """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_mu: float, - w_sigma: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - transpose: bool = False, - atomic: bool = True, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_mu = w_mu - self.w_sigma = w_sigma - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -### TAICHI CUSTOMIZED OPERATOR IMPLEMENTATION ### - -class _CSRLayer_taichi(Layer, SupportSTDP): - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = True, - ): - super().__init__(name=name, mode=mode) - - assert isinstance(conn, connect.TwoEndConnector) - assert sharding is None, 'Currently this model does not support sharding.' 
- self.conn = conn - self.sharding = sharding - self.transpose = transpose - - # connection - self.indices, self.indptr = self.conn.require('csr') - - # weight - weight = init.parameter(weight, (self.indices.size,)) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if bm.isscalar(self.weight): - raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.') - if self.weight.shape != self.indices.shape: - raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: # update on presynaptic spike - spike = on_pre['spike'] - trace = on_pre['trace'] - self.weight.value = csr_on_pre_update_taichi(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max) - if on_post is not None: # update on postsynaptic spike - if not hasattr(self, '_pre_ids'): - with jax.ensure_compile_time_eval(): - self._pre_ids, self._post_indptr, self.w_indices = csr2csc( - [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size) - ) - spike = on_post['spike'] - trace = on_post['trace'] - self.weight.value = csc_on_post_update_taichi(self.weight.value, self._pre_ids, self._post_indptr, - self.w_indices, spike, trace, w_min, w_max) - - -class CSRLinear_taichi(_CSRLayer_taichi): - r"""Synaptic matrix multiplication with CSR sparse computation(taichi customized operator). - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a CSR sparse matrix. - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. 
Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = True, - ): - super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) - - def update(self, x): - if x.ndim == 1: - return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) - elif x.ndim > 1: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_csrmv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_csrmv(self, x): - return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) - -class EventCSRLinear_taichi(_CSRLayer_taichi): - r"""Synaptic matrix multiplication with event CSR sparse computation(taichi customized operator). - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weight using a CSR sparse matrix. - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. 
- """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = True, - ): - super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) - - def update(self, x): - if x.ndim == 1: - return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) - elif x.ndim > 1: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_csrmv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_csrmv(self, x): - return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) - -@ti.kernel -def _cpu_csr_on_pre_update_taichi(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - w_value = w[0] - out_w[:] = w_value - w_min_value = w_min[0] - w_max_value = w_max[0] - for i in range(spike.shape[0]): # pre id - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): # synapse id - j = indices[k] # post id - out_w[k] = ti.min(ti.max(out_w[k] + trace[j], w_min_value), w_max_value) - -csr_on_pre_update_prim_taichi = bm.XLACustomOp(_cpu_csr_on_pre_update_taichi) - - -def csr_on_pre_update_taichi(w, indices, indptr, spike, trace, w_min=None, w_max=None): - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - w = jax.Array(w) - w_min = jax.Array(w_min) - w_max = jax.Array(w_max) - return csr_on_pre_update_prim_taichi(w, indices, indptr, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(w.shape, 
w.dtype)])[0] - -@ti.kernel -def _cpu_csc_on_pre_update_taichi(w: ti.types.ndarray(ndim=1), - post_ids: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - w_ids: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - w_value = w[0] - out_w[:] = w_value - w_min_value = w_min[0] - w_max_value = w_max[0] - - for i in range(spike.shape[0]): # post id - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = post_ids[k] # pre id - l = w_ids[k] # syn id - out_w[l] = ti.min(ti.max(out_w[l] + trace[j], w_min_value), w_max_value) - -csc_on_pre_update_prim_taichi = bm.XLACustomOp(_cpu_csc_on_pre_update_taichi) - - -def csc_on_post_update_taichi(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None): - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - w = jax.Array(w) - w_min = jax.Array(w_min) - w_max = jax.Array(w_max) - return csc_on_pre_update_prim_taichi(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] - -class JitFPHomoLinear_taichi(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is the same :math:`weight`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - weight: float. The synaptic value at each position. - seed: int. 
The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - weight: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = False, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - - -class JitFPUniformLinear_taichi(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). - - It performs the computation of: - - .. 
math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_low: float. The lowest value of the uniform distribution. - w_high: float. The highest value of the uniform distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. 
- """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_low: float, - w_high: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = False, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_low = w_low - self.w_high = w_high - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - - -class JitFPNormalLinear_taichi(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. 
The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_mu: float. The center of the normal distribution. - w_sigma: float. The standard variance of the normal distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_mu: float, - w_sigma: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - transpose: bool = False, - atomic: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_mu = w_mu - self.w_sigma = w_sigma - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - - -class EventJitFPHomoLinear_taichi(Layer): - r"""Synaptic matrix multiplication with the 
just-in-time connectivity(taichi customized operator). - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is the same :math:`weight`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - weight: float. The synaptic value at each position. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. 
- """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - weight: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = True, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 1000000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - - -class EventJitFPUniformLinear_taichi(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. - - Args: - num_in: int. The number of the input feature. 
A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_low: float. The lowest value of the uniform distribution. - w_high: float. The highest value of the uniform distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_low: float, - w_high: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = True, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_low = w_low - self.w_high = w_high - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - - -class 
EventJitFPNormalLinear_taichi(Layer): r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). It performs the computation of: diff --git a/brainpy/dnn/linear.py b/brainpy/dnn/linear.py index 44a51b9da..762c3c282 100644 --- a/brainpy/dnn/linear.py +++ b/brainpy/dnn/linear.py @@ -14,12 +14,4 @@ EventJitFPHomoLinear as EventJitFPHomoLinear, EventJitFPNormalLinear as EventJitFPNormalLinear, EventJitFPUniformLinear as EventJitFPUniformLinear, - CSRLinear_taichi as CSRLinear_taichi, - EventCSRLinear_taichi as EventCSRLinear_taichi, - JitFPHomoLinear_taichi as JitFPHomoLinear_taichi, - JitFPUniformLinear_taichi as JitFPUniformLinear_taichi, - JitFPNormalLinear_taichi as JitFPNormalLinear_taichi, - EventJitFPHomoLinear_taichi as EventJitFPHomoLinear_taichi, - EventJitFPNormalLinear_taichi as EventJitFPNormalLinear_taichi, - EventJitFPUniformLinear_taichi as EventJitFPUniformLinear_taichi, ) From 7499c1e5b1a1c92cd428b6d93fef42a883081b31 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 25 Jan 2024 17:33:29 +0800 Subject: [PATCH 06/27] Update linear.py --- brainpy/_src/dnn/linear.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index d3de95cbf..8b85137f9 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -819,7 +819,7 @@ def __init__( class JitFPHomoLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -898,7 +898,7 @@ def _batch_mv(self, x): class JitFPUniformLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + r"""Synaptic matrix multiplication with the just-in-time connectivity. 
It performs the computation of: @@ -978,7 +978,7 @@ def _batch_mv(self, x): class JitFPNormalLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1058,7 +1058,7 @@ def _batch_mv(self, x): class EventJitFPHomoLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1137,7 +1137,7 @@ def _batch_mv(self, x): class EventJitFPUniformLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1217,7 +1217,7 @@ def _batch_mv(self, x): class EventJitFPNormalLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + r"""Synaptic matrix multiplication with the just-in-time connectivity. 
It performs the computation of: From be87e9faaab2a149ba0b7c3b5e5ed2141171f784 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 25 Jan 2024 17:34:21 +0800 Subject: [PATCH 07/27] Update test_linear.py --- brainpy/_src/dnn/tests/test_linear.py | 128 -------------------------- 1 file changed, 128 deletions(-) diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 48fc60a01..7fc89526c 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -213,133 +213,5 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): self.assertTrue(y2.shape == shape + (200,)) bm.clear_buffer_memory() - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_CSRLinear_taichi(self, conn): - bm.random.seed() - f = bp.dnn.CSRLinear_taichi(conn, weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - - x = bm.random.random((100,)) - y = f(x) - self.assertTrue(y.shape == (100,)) - bm.clear_buffer_memory() - - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_EventCSRLinear_taichi(self,conn): - bm.random.seed() - f=bp.layers.EventCSRLinear_taichi(conn,weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - x = bm.random.random((100,)) - y = f(x) - self.assertTrue(y.shape == (100,)) - bm.clear_buffer_memory() - - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPHomoLinear_taichi(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.JitFPHomoLinear_taichi(100, 200, prob, weight, seed=123) - x = bm.random.random(shape + 
(100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPUniformLinear_taichi(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.JitFPUniformLinear_taichi(100, 200, prob, w_low, w_high, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPNormalLinear_taichi(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.JitFPNormalLinear_taichi(100, 200, prob, w_mu, w_sigma, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPHomoLinear_taichi(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.EventJitFPHomoLinear_taichi(100, 200, prob, weight, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPUniformLinear_taichi(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.EventJitFPUniformLinear_taichi(100, 200, prob, w_low, w_high, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - 
self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPNormalLinear_taichi(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.EventJitFPNormalLinear_taichi(100, 200, prob, w_mu, w_sigma, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - if __name__ == '__main__': absltest.main() From ae95fad9fb2a04063e3b09adf9fcbb2e52a6322a Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 25 Jan 2024 19:52:16 +0800 Subject: [PATCH 08/27] [dnn, math] Fix bugs --- brainpy/_src/dnn/linear.py | 54 ++++++------------- .../_src/math/op_register/taichi_aot_based.py | 2 +- 2 files changed, 17 insertions(+), 39 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 8b85137f9..de2106010 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -642,24 +642,17 @@ def _batch_csrmv(self, x): shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) -@ti.kernel -def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - w_value = w[0] - out_w[:] = w_value - w_min_value = w_min[0] - w_max_value = w_max[0] - for i in range(spike.shape[0]): # pre id +@numba.njit(nogil=True, fastmath=True, parallel=False) +def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): + out_w[:] = w + w_min = w_min[()] + w_max = w_max[()] + for i in 
numba.prange(spike.shape[0]): # pre id if spike[i]: for k in range(indptr[i], indptr[i + 1]): # synapse id j = indices[k] # post id - out_w[k] = ti.min(ti.max(out_w[k] + trace[j], w_min_value), w_max_value) + # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max) + out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update) @@ -669,33 +662,21 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): w_min = -np.inf if w_max is None: w_max = np.inf - w = jax.Array(w) - w_min = jax.Array(w_min) - w_max = jax.Array(w_max) return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] -@ti.kernel -def _cpu_csc_on_pre_update(w: ti.types.ndarray(ndim=1), - post_ids: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - w_ids: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - w_value = w[0] - out_w[:] = w_value - w_min_value = w_min[0] - w_max_value = w_max[0] - - for i in range(spike.shape[0]): # post id +@numba.njit(nogil=True, fastmath=True, parallel=False) +def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): + out_w[:] = w + w_min = w_min[()] + w_max = w_max[()] + for i in numba.prange(spike.shape[0]): # post id if spike[i]: for k in range(indptr[i], indptr[i + 1]): j = post_ids[k] # pre id l = w_ids[k] # syn id - out_w[l] = ti.min(ti.max(out_w[l] + trace[j], w_min_value), w_max_value) + out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) + csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) @@ -705,9 +686,6 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_m w_min = -np.inf if w_max is None: w_max = np.inf - w = jax.Array(w) - w_min = 
jax.Array(w_min) - w_max = jax.Array(w_max) return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 878b205cf..6f28c4a35 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -361,7 +361,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs): try: _build_kernel(source_md5_encode, kernel, ins_dict, outs_dict, platform) except Exception as e: - os.removedirs(os.path.join(kernels_aot_path, source_md5_encode)) + os.removedirs(os.path.join(kernels_aot_path, kernel.__name__, source_md5_encode)) raise RuntimeError(f'Failed to build kernel:\n\n {codes}') from e # returns From b993e135b711014b3155d7b3842fd4b532ed4806 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 25 Jan 2024 22:28:05 +0800 Subject: [PATCH 09/27] [math] Fix bugs --- brainpy/_src/math/op_register/taichi_aot_based.py | 7 +++++-- brainpy/_src/math/sparse/_csr_mv_taichi.py | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 6f28c4a35..96ebabfa7 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -347,7 +347,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs): # kernel to code codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform) - source_md5_encode = kernel.__name__ + '/' + encode_md5(codes) + source_md5_encode = os.path.join(kernel.__name__, encode_md5(codes)) # create ins, outs dict from kernel's args in_num = len(ins) @@ -361,7 +361,10 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs): try: _build_kernel(source_md5_encode, kernel, ins_dict, outs_dict, platform) except Exception as e: - 
os.removedirs(os.path.join(kernels_aot_path, kernel.__name__, source_md5_encode)) + try: + os.removedirs(os.path.join(kernels_aot_path, source_md5_encode)) + except Exception: + raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e raise RuntimeError(f'Failed to build kernel:\n\n {codes}') from e # returns diff --git a/brainpy/_src/math/sparse/_csr_mv_taichi.py b/brainpy/_src/math/sparse/_csr_mv_taichi.py index 5038e372e..84cae5554 100644 --- a/brainpy/_src/math/sparse/_csr_mv_taichi.py +++ b/brainpy/_src/math/sparse/_csr_mv_taichi.py @@ -271,6 +271,10 @@ def csrmv_taichi( raise ValueError('indices should be a 1D vector with integer type.') if not jnp.issubdtype(indptr.dtype, jnp.integer): raise ValueError('indptr should be a 1D vector with integer type.') + + # if the shape of indices is (0,), then we return a zero vector + if indices.shape[0] == 0: + return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] From 420cbba7b9c37516eced60ed9f571b9cc5359006 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 26 Jan 2024 11:56:37 +0800 Subject: [PATCH 10/27] Update linear.py --- brainpy/_src/dnn/linear.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index de2106010..c3f287f64 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -4,11 +4,13 @@ import numbers from typing import Dict, Optional, Union, Callable +from brainpy._src.dependency_check import import_taichi import jax import jax.numpy as jnp import numba import numpy as np -import taichi as ti + +ti = import_taichi() from brainpy import math as bm from brainpy._src import connect, initialize as init From 33f21b92095fbafe4f8a52bfba123ee394fe4a3d Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 28 Jan 2024 19:14:09 +0800 Subject: [PATCH 
11/27] Refactor operators --- brainpy/_src/math/event/__init__.py | 1 - brainpy/_src/math/event/_csr_matvec.py | 555 ++++++- brainpy/_src/math/event/_csr_matvec_taichi.py | 497 ------ .../_src/math/event/tests/test_event_csrmv.py | 34 +- .../event/tests/test_event_csrmv_taichi.py | 54 +- brainpy/_src/math/jitconn/__init__.py | 4 +- brainpy/_src/math/jitconn/_event_matvec.py | 1371 ++++++++++++++++- .../_src/math/jitconn/_event_matvec_taichi.py | 1277 --------------- brainpy/_src/math/jitconn/_matvec.py | 1134 +++++++++++++- brainpy/_src/math/jitconn/_matvec_taichi.py | 911 ----------- .../math/jitconn/tests/test_event_matvec.py | 37 +- .../jitconn/tests/test_event_matvec_taichi.py | 38 +- .../_src/math/jitconn/tests/test_matvec.py | 37 +- .../math/jitconn/tests/test_matvec_taichi.py | 38 +- brainpy/_src/math/sparse/__init__.py | 1 - brainpy/_src/math/sparse/_csr_mv.py | 354 ++++- brainpy/_src/math/sparse/_csr_mv_taichi.py | 304 ---- .../math/sparse/tests/test_csrmv_taichi.py | 24 +- brainpy/math/event.py | 1 - brainpy/math/jitconn.py | 8 - brainpy/math/sparse.py | 1 - 21 files changed, 3543 insertions(+), 3138 deletions(-) delete mode 100644 brainpy/_src/math/event/_csr_matvec_taichi.py delete mode 100644 brainpy/_src/math/jitconn/_event_matvec_taichi.py delete mode 100644 brainpy/_src/math/jitconn/_matvec_taichi.py delete mode 100644 brainpy/_src/math/sparse/_csr_mv_taichi.py diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index 865d682a0..631129558 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,5 +1,4 @@ from ._info_collection import * from ._csr_matvec import * -from ._csr_matvec_taichi import * diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 9da0cf524..53e110dd4 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -14,6 +14,7 @@ from functools import partial from typing import 
Union, Tuple +import brainpy.math as bm import jax import jax.numpy as jnp import numba @@ -22,10 +23,13 @@ from jax.interpreters import ad, xla from jax.lib import xla_client +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching) + register_general_batching, + XLACustomOp) from brainpy._src.math.sparse._csr_mv import csrmv as normal_csrmv +from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi from brainpy._src.math.sparse._utils import csr_to_coo from brainpy._src.dependency_check import (import_brainpylib_gpu_ops) from brainpy.errors import GPUOperatorNotFound @@ -34,8 +38,75 @@ 'csrmv' ] +ti = import_taichi() + def csrmv( + data: Union[float, jax.Array], + indices: jax.Array, + indptr: jax.Array, + events: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + method: str = None +) -> jax.Array: + """Product of a sparse CSR matrix and a dense event vector. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + events: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + If ``transpose=True``, the operator will compute based on the + event-driven property of the ``events`` vector. + method: str + The method used to compute Matrix-Vector Multiplication. For cpu platform, + the default method is ``taichi``. For gpu platform, the default method is + ``brainpylib``. 
+ The candidate methods are: + + - ``taichi``: using Taichi kernel. + - ``brainpylib``: using brainpylib operators. + + Returns + ------- + y : Array + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + + if method is None: + if bm.get_platform() == 'cpu': + method = 'taichi' + elif bm.get_platform() == 'gpu': + method = 'brainpylib' + + if method == 'taichi': + return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) + elif method == 'brainpylib': + return csrmv_brainpylib(data, indices, indptr, events, shape=shape, transpose=transpose) + else: + raise ValueError(f'Unknown method {method}.') + + +### BRAINPYLIB ### + +def csrmv_brainpylib( data: Union[float, jax.Array], indices: jax.Array, indptr: jax.Array, @@ -109,7 +180,6 @@ def csrmv( # computing return event_csr_matvec_p.bind(data, indices, indptr, events, shape=shape, transpose=transpose) - # ---------------------------------------------------------- # event csr matvec # ---------------------------------------------------------- @@ -555,3 +625,484 @@ def _event_csr_matvec_transpose(ct, values, indices, indptr, events, *, shape, t ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose register_general_batching(event_csr_matvec_p) # batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule + + +### TAICHI ### + +def csrmv_taichi( + data: Union[float, jax.Array], + indices: jax.Array, + indptr: jax.Array, + events: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False +) -> jax.Array: + """Product of a sparse CSR matrix and a dense event vector. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. 
+ events: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + If ``transpose=True``, the operator will compute based on the + event-driven property of the ``events`` vector. + + Returns + ------- + y : Array + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + data = as_jax(data) + indices = as_jax(indices) + indptr = as_jax(indptr) + events = as_jax(events) + + # checking + data = jnp.atleast_1d(data) + if np.ndim(data) == 1: + if data.shape[0] not in [1, indices.shape[0]]: + raise ValueError('The size of data should be 1 or be consistent with indices.' + f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') + else: + raise ValueError('data should be a scalar or 1D vector. ' + f'But we got {np.ndim(data)}-D array.') + if np.ndim(indices) != 1: + raise ValueError('indices should be a 1D vector with integer type.') + if np.ndim(indptr) != 1: + raise ValueError('indptr should be a 1D vector with integer type.') + if indices.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: + raise ValueError( + 'indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') + if indptr.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: + raise ValueError( + 'indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') + if np.ndim(events) != 1: + raise ValueError('events should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if 
events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + + # if the shape of indices is (0,), then we return a zero vector + if indices.shape[0] == 0: + return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) + + return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] + +# ------------- +# CPU operators +# ------------- + +# 1. The benchmarking shows that the performance of the following transpose +# kernels is maximized when using serialized mode +# 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable +# arguments, we have to define each kernel separately when the +# non-differentiable/non-jittable arguments are different. + + +@ti.kernel +def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + +@ti.kernel +def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + +@ti.kernel +def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in 
range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + +@ti.kernel +def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + +@ti.kernel +def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += value + out[row_i] = r + + +@ti.kernel +def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += values[j] + out[row_i] = r + + +@ti.kernel +def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. 
+ for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += value + out[row_i] = r + + +@ti.kernel +def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += values[j] + out[row_i] = r + + +# ------------- +# GPU operators +# ------------- + +# 1. GPU kernels are different from the CPU ones, since the GPU kernels need +# to use warp-level parallelism to achieve the best performance. + + +@ti.kernel +def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + +@ti.kernel +def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + +# TODO +# It is important to note that the following warp-based kernels +# should be improved, since the atomic_add for each thread is not +# very efficient. Instead, the warp-level reduction primitive +# should be used. +# see ``warp_reduce_sum()`` function in tifunc.py. 
+# However, currently Taichi does not support general warp-level primitives. + + +@ti.kernel +def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + +@ti.kernel +def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + +@ti.kernel +def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + +@ti.kernel +def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < 
end_index: + out[indices[j]] += values[j] + j += 32 + + +@ti.kernel +def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + +@ti.kernel +def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive + +def raw_csrmv_taichi( + data: Union[float, jax.Array], + indices: jax.Array, + indptr: jax.Array, + events: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False +): + if transpose: + if events.dtype == jnp.bool_: + if data.shape[0] == 1: + prim = _event_csrmv_transpose_bool_homo_p + else: + prim = _event_csrmv_transpose_bool_heter_p + else: + if data.shape[0] == 1: + prim = _event_csrmv_transpose_homo_p + else: + prim = _event_csrmv_transpose_heter_p + else: + if events.dtype == jnp.bool_: + if data.shape[0] == 1: + prim = _event_csrmv_bool_homo_p + else: + prim = _event_csrmv_bool_heter_p + else: + if data.shape[0] == 1: + prim = _event_csrmv_homo_p + else: + prim = _event_csrmv_heter_p + + # computing + return prim(data, + indices, + indptr, + events, + outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)], + transpose=transpose, + shape=shape) + + +def 
_define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_csr_matvec_jvp_values, None, None, _event_csr_matvec_jvp_events) + prim.def_transpose_rule(_event_csr_matvec_transpose) + return prim + + +# transpose bool homo +_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, + _event_csr_matvec_transpose_bool_homo_gpu) + +# transpose homo +_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, _event_csr_matvec_transpose_homo_gpu) + +# not transpose bool homo +_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, _event_csr_matvec_bool_homo_gpu) + +# not transpose homo +_event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, _event_csr_matvec_homo_gpu) + +# transpose bool heter +_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, + _event_csr_matvec_transpose_bool_heter_gpu) + +# transpose heter +_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, + _event_csr_matvec_transpose_heter_gpu) + +# not transpose bool heter +_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, _event_csr_matvec_bool_heter_gpu) + +# not transpose heter +_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu) + + + + +def _event_csr_matvec_jvp_values(val_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) + + +def _event_csr_matvec_jvp_events(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + + +def _event_csr_matvec_transpose( + ct, values, indices, indptr, events, *, outs, transpose, shape +): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot 
transpose with respect to sparse indices.") + if ad.is_undefined_primal(events): + ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] + return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) + else: + if type(ct[0]) is ad.Zero: + ct_values = ad.Zero(values) + else: + if values.aval.shape[0] == 1: # scalar + ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] + ct_values = jnp.inner(ct[0], ct_values) + else: # heterogeneous values + row, col = csr_to_coo(indices, indptr) + ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] + return ct_values, indices, indptr, events \ No newline at end of file diff --git a/brainpy/_src/math/event/_csr_matvec_taichi.py b/brainpy/_src/math/event/_csr_matvec_taichi.py deleted file mode 100644 index 2ee47d838..000000000 --- a/brainpy/_src/math/event/_csr_matvec_taichi.py +++ /dev/null @@ -1,497 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Union, Tuple - -import jax -import jax.numpy as jnp -import numpy as np -from jax.interpreters import ad - -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse._csr_mv_taichi import raw_csrmv_taichi as normal_csrmv_taichi -from brainpy._src.math.sparse._utils import csr_to_coo - -ti = import_taichi() - -__all__ = [ - 'csrmv_taichi' -] - - -# ------------- -# CPU operators -# ------------- - -# 1. The benchmarking shows that the performance of the following transpose -# kernels is maximized when using serialized mode -# 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable -# arguments, we have to define each kernel separately when the -# non-differentiable/non-jittable arguments are different. 
- - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - 
for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += values[j] - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += values[j] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - -# 1. GPU kernels are different from the CPU ones, since the GPU kernels need -# to use warp-level parallelism to achieve the best performance. 
- - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -# TODO -# It is important to note that the following warp-based kernels -# should be improved, since the atomic_add for each thread is not -# very efficient. Instead, the warp-level reduction primitive -# should be used. -# see ``warp_reduce_sum()`` function in tifunc.py. -# However, currently Taichi does not support general warp-level primitives. - - -@ti.kernel -def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. 
- j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. 
- j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -def _event_csr_matvec_jvp_values(val_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) - - -def _event_csr_matvec_jvp_events(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) - - -def _event_csr_matvec_transpose( - ct, values, indices, indptr, events, *, outs, transpose, shape -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] - return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) - else: - if type(ct[0]) is ad.Zero: - ct_values = ad.Zero(values) - else: - if values.aval.shape[0] == 1: # scalar - ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] - ct_values = jnp.inner(ct[0], ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] - 
return ct_values, indices, indptr, events - -def raw_csrmv_taichi( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -): - if transpose: - if events.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csrmv_transpose_bool_homo_p - else: - prim = _event_csrmv_transpose_bool_heter_p - else: - if data.shape[0] == 1: - prim = _event_csrmv_transpose_homo_p - else: - prim = _event_csrmv_transpose_heter_p - else: - if events.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csrmv_bool_homo_p - else: - prim = _event_csrmv_bool_heter_p - else: - if data.shape[0] == 1: - prim = _event_csrmv_homo_p - else: - prim = _event_csrmv_heter_p - - # computing - return prim(data, - indices, - indptr, - events, - outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)], - transpose=transpose, - shape=shape) - -def csrmv_taichi( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. 
- - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - events = as_jax(events) - - # checking - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - else: - raise ValueError('data should be a scalar or 1D vector. ' - f'But we got {np.ndim(data)}-D array.') - if np.ndim(indices) != 1: - raise ValueError('indices should be a 1D vector with integer type.') - if np.ndim(indptr) != 1: - raise ValueError('indptr should be a 1D vector with integer type.') - if indices.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: - raise ValueError( - 'indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') - if indptr.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: - raise ValueError( - 'indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') - if np.ndim(events) != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - # if the shape of indices is (0,), then we return a zero vector - if indices.shape[0] == 0: - return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) - - return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] - - -def 
_define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matvec_jvp_values, None, None, _event_csr_matvec_jvp_events) - prim.def_transpose_rule(_event_csr_matvec_transpose) - return prim - - -# transpose bool homo -_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, - _event_csr_matvec_transpose_bool_homo_gpu) - -# transpose homo -_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, _event_csr_matvec_transpose_homo_gpu) - -# not transpose bool homo -_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, _event_csr_matvec_bool_homo_gpu) - -# not transpose homo -_event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, _event_csr_matvec_homo_gpu) - -# transpose bool heter -_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, - _event_csr_matvec_transpose_bool_heter_gpu) - -# transpose heter -_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, - _event_csr_matvec_transpose_heter_gpu) - -# not transpose bool heter -_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, _event_csr_matvec_bool_heter_gpu) - -# not transpose heter -_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu) diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 3ca456b0b..55253ab47 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -16,6 +16,8 @@ if platform.system() == 'Windows' and not is_manual_test: pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) +brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') +taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') def sum_op(op): def func(*args, **kwargs): @@ -56,18 +58,18 @@ def 
test_homo(self, shape, transpose, homo_data): events = rng.random(shape[0] if transpose else shape[1]) < 0.1 heter_data = bm.ones(indices.shape) * homo_data - r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = bm.event.csrmv(heter_data, indices, indptr, events, shape=shape, transpose=transpose) + r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) - r3 = bm.event.csrmv(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r3 = brainpylib_csr_matvec(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r3)) dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r4 = (events @ dense) if transpose else (dense @ events) self.assertTrue(bm.allclose(r1, r4)) - r5 = bm.event.csrmv(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r5 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r5)) bm.clear_buffer_memory() @@ -97,7 +99,7 @@ def test_homo_vmap(self, shape, transpose, homo_data): # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, + f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) f2 = jax.vmap( partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), @@ -106,7 +108,7 @@ def test_homo_vmap(self, shape, transpose, homo_data): self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) # vmap 'events' - f3 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, 
indptr, + f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr, shape=shape, transpose=transpose)) f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), homo_data, indices, indptr, shape=shape, transpose=transpose)) @@ -114,7 +116,7 @@ def test_homo_vmap(self, shape, transpose, homo_data): self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) f6 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose, method='cusparse')) vmap_data1 = bm.as_jax([homo_data] * 10) @@ -152,7 +154,7 @@ def test_homo_grad(self, shape, transpose, homo_data): dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) # grad 'data' - r1 = jax.grad(sum_op(bm.event.csrmv))( + r1 = jax.grad(sum_op(brainpylib_csr_matvec))( homo_data, indices, indptr, events, shape=shape, transpose=transpose) r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) @@ -162,7 +164,7 @@ def test_homo_grad(self, shape, transpose, homo_data): self.assertTrue(bm.allclose(r1, r3)) # grad 'events' - r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( + r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) @@ -198,7 +200,7 @@ def test_heter(self, shape, transpose): events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 heter_data = bm.as_jax(rng.random(indices.shape)) - r1 = 
bm.event.csrmv(heter_data, indices, indptr, events, + r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) r2 = partial(bm.sparse.csrmv, method='cusparse')(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) @@ -208,7 +210,7 @@ def test_heter(self, shape, transpose): r3 = (events @ dense) if transpose else (dense @ events) self.assertTrue(bm.allclose(r1, r3)) - r4 = bm.event.csrmv(heter_data, indices, indptr, events.astype(float), + r4 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r4)) @@ -239,7 +241,7 @@ def test_heter_vmap(self, shape, transpose): # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, + f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) f2 = jax.vmap( partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), @@ -249,7 +251,7 @@ def test_heter_vmap(self, shape, transpose): # vmap 'events' data = bm.as_jax(rng.random(indices.shape)) - f3 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr, + f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr, shape=shape, transpose=transpose)) f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), data, indices, indptr, shape=shape, transpose=transpose)) @@ -257,7 +259,7 @@ def test_heter_vmap(self, shape, transpose): self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, + f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) f6 = jax.vmap(lambda dd, ee: partial(bm.sparse.csrmv, 
method='cusparse')(dd, indices, indptr, ee, shape=shape, transpose=transpose)) @@ -295,7 +297,7 @@ def test_heter_grad(self, shape, transpose): # grad 'data' data = bm.as_jax(rng.random(indices.shape)) - r1 = jax.grad(sum_op(bm.event.csrmv))( + r1 = jax.grad(sum_op(brainpylib_csr_matvec))( data, indices, indptr, events, shape=shape, transpose=transpose) r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) @@ -309,7 +311,7 @@ def test_heter_grad(self, shape, transpose): self.assertTrue(bm.allclose(r1, r3)) # grad 'events' - r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( + r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py b/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py index c81aee7c0..781e3c91c 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py @@ -19,6 +19,8 @@ def func(*args, **kwargs): return func +brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') +taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') class Test_event_csr_matvec_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -42,8 +44,8 @@ def test_homo(self, transpose, shape, homo_data): events = rng.random(shape[0] if transpose else shape[1]) < 0.1 heter_data = bm.ones(indices.shape) * homo_data - r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = bm.event.csrmv_taichi(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, 
shape=shape, transpose=transpose) + r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -65,24 +67,24 @@ def test_homo_vmap(self, shape, transpose, homo_data): # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, + f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events, + f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax([homo_data] * 10) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) # vmap 'events' - f3 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr, + f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(bm.event.csrmv_taichi, homo_data, indices, indptr, + f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.event.csrmv_taichi(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax([homo_data] * 10) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 @@ -110,16 +112,16 @@ def 
test_homo_grad(self, shape, transpose, homo_data): dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) # grad 'data' - r1 = jax.grad(sum_op(bm.event.csrmv))( + r1 = jax.grad(sum_op(brainpylib_csr_matvec))( homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(bm.event.csrmv_taichi))( + r2 = jax.grad(sum_op(taichi_csr_matvec))( homo_data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' - r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( + r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=3)( + r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) @@ -141,9 +143,9 @@ def test_heter(self, shape, transpose): events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 heter_data = bm.as_jax(rng.random(indices.shape)) - r1 = bm.event.csrmv(heter_data, indices, indptr, events, + r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = bm.event.csrmv_taichi(heter_data, indices, indptr, events, + r2 = taichi_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -167,26 +169,26 @@ def test_heter_vmap(self, shape, transpose): # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, + f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events, + f2 = 
jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) # vmap 'events' data = bm.as_jax(rng.random(indices.shape)) - f3 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr, + f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(bm.event.csrmv_taichi, data, indices, indptr, + f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, + f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.event.csrmv_taichi(dd, indices, indptr, ee, + f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 @@ -215,22 +217,22 @@ def test_heter_grad(self, shape, transpose): # grad 'data' data = bm.as_jax(rng.random(indices.shape)) - r1 = jax.grad(sum_op(bm.event.csrmv))( + r1 = jax.grad(sum_op(brainpylib_csr_matvec))( data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(bm.event.csrmv_taichi))( + r2 = jax.grad(sum_op(taichi_csr_matvec))( data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' - r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( + r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - 
r4 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=3)( + r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) - r5 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))( + r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=(0, 3))( + r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index 439324152..a79cdc982 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,5 +1,3 @@ from ._matvec import * -from ._matvec_taichi import * -from ._event_matvec import * -from ._event_matvec_taichi import * +from ._event_matvec import * \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index d739919f7..e20cf804c 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -3,6 +3,7 @@ from functools import partial from typing import Tuple, Optional +import brainpy.math as bm import jax import numpy as np from jax import numpy as jnp, dtypes @@ -10,18 +11,29 @@ from jax.interpreters import xla, ad from jax.lib import xla_client -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.jitconn._matvec import (mv_prob_homo_p, mv_prob_uniform_p, mv_prob_normal_p, mv_prob_homo, 
mv_prob_uniform, - mv_prob_normal) + mv_prob_normal, + _general_checking, + raw_mv_prob_homo, + raw_mv_prob_uniform, + raw_mv_prob_normal, + _mv_prob_homo_transpose, + _mv_prob_uniform_transpose, + _mv_prob_normal_transpose, + _reverse) from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_register import register_general_batching +from brainpy._src.math.op_register import register_general_batching, XLACustomOp +from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) from brainpy.errors import GPUOperatorNotFound +ti = import_taichi() + __all__ = [ 'event_mv_prob_homo', 'event_mv_prob_uniform', @@ -38,6 +50,94 @@ def event_mv_prob_homo( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, +) -> jax.Array: + if method is None: + if bm.get_platform() == 'cpu': + method = 'taichi' + elif bm.get_platform() == 'gpu': + if outdim_parallel: + method = 'brainpylib' + else: + method = 'taichi' + + if method == 'taichi': + return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + elif method == 'brainpylib': + return event_mv_prob_homo_brainpylib(events, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + else: + raise ValueError(f'Unknown method {method}.') + +event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ + + +def event_mv_prob_uniform( + events: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + if method is None: + if bm.get_platform() == 'cpu': + method = 'taichi' + elif bm.get_platform() == 'gpu': + if outdim_parallel: + method = 'brainpylib' + else: + method = 'taichi' + + if method == 'taichi': + return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, 
transpose=transpose, outdim_parallel=outdim_parallel) + elif method == 'brainpylib': + return event_mv_prob_uniform_brainpylib(events, w_low, w_high, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + else: + raise ValueError(f'Unknown method {method}.') + +event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal( + events: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + if method is None: + if bm.get_platform() == 'cpu': + method = 'taichi' + elif bm.get_platform() == 'gpu': + if outdim_parallel: + method = 'brainpylib' + else: + method = 'taichi' + + if method == 'taichi': + return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + elif method == 'brainpylib': + return event_mv_prob_uniform_brainpylib(events, w_mu, w_sigma, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + else: + raise ValueError(f'Unknown method {method}.') + +### BRAINPYLIB ### + +def event_mv_prob_homo_brainpylib( + events: jax.Array, + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, ) -> jax.Array: events = as_jax(events) weight = jnp.atleast_1d(as_jax(weight)) @@ -57,10 +157,10 @@ def event_mv_prob_homo( return r -event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ +event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ -def event_mv_prob_uniform( +def event_mv_prob_uniform_brainpylib( events: jax.Array, w_low: float, w_high: float, @@ -90,10 +190,10 @@ def event_mv_prob_uniform( outdim_parallel=outdim_parallel)[0] -event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ +event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ -def event_mv_prob_normal( 
+def event_mv_prob_normal_brainpylib( events: jax.Array, w_mu: float, w_sigma: float, @@ -122,8 +222,7 @@ def event_mv_prob_normal( transpose=transpose, outdim_parallel=outdim_parallel)[0] - -event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__ +event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ def _event_matvec_prob_homo_abstract( @@ -665,3 +764,1257 @@ def _event_matvec_prob_normal_transpose( register_general_batching(event_mv_prob_normal_p) ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose + + +### TAICHI ### + +def event_mv_prob_homo_taichi( + events: jax.Array, + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + events: Array, ndarray + The events. + weight: float + The value of the random matrix. + conn_prob: float + The connection probability. 
+ shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to make the just-in-time generated :math:`M^T` the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + events = as_jax(events) + if isinstance(weight, float): weight = as_jax(weight) + weight = jnp.atleast_1d(as_jax(weight)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_event_mv_prob_homo(events, weight, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + + +def event_mv_prob_uniform_taichi( + events: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. + + This operator supports ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is drawn uniformly between `w_low` and `w_high`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`).
+ + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + events: Array, ndarray + The events. + w_low: float + Lower boundary of the output interval. + w_high: float + Upper boundary of the output interval. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to make the just-in-time generated :math:`M^T` the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + events = as_jax(events) + if isinstance(w_low, float): w_low = as_jax(w_low) + if isinstance(w_high, float): w_high = as_jax(w_high) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_event_mv_prob_uniform(events, w_low, w_high, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + +def event_mv_prob_normal_taichi( + events: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a normal distribution for its value.
+ + This operator supports ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is drawn from a normal distribution with mean `w_mu` and standard deviation `w_sigma`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + events: Array, ndarray + The events. + w_mu: float + Mean (centre) of the distribution. + w_sigma: float + Standard deviation (spread or “width”) of the distribution. Must be non-negative. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to make the just-in-time generated :math:`M^T` the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`.
+ """ + events = as_jax(events) + if isinstance(w_mu, float): w_mu = as_jax(w_mu) + if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma) + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_event_mv_prob_normal(events, w_mu, w_sigma, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + +# ------------- +# CPU function +# ------------- +# For each non-zero event value, it generates a random key using a +# function lfsr88_key and then uses this key to compute random integers +# and update the out array based on the computed indices and weight. +# +# The function is likely designed to be parallelized. + + +@ti.kernel +def _event_mv_prob_homo_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_homo_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. 
+ key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col]: + r += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +# ------------- +# GPU function +# ------------- +# Contrary to the CPU functions, for each column, +# this function uses 32 threads (one warp) to make +# the just-in-time random generation parallelized. + + +@ti.kernel +def _event_mv_prob_homo_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) # rows per lane: (num_row + 1) // 32, at least 1 + + for i in range(num_col * 32): + i_col = i >> 5 # 32 consecutive values of i share one column + if events[i_col]: + index = i & 31 # lane number (0..31) within that group of 32 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) # each (column, lane) pair gets its own random stream + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_homo_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) # NOTE(review): derived from num_row but used below to stride columns (bounded by num_col) - confirm this is intended for non-square shapes + + for i in range(num_row * 32): + i_row = i >> 5 # 32 consecutive values of i share one output row + index = i & 31 # lane number (0..31) within that group of 32 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0.
+ key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison without if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + +def _reverse(shape): + return shape[::-1] + +# ------------- +# CPU function +# ------------- +# For each non-zero event value, it generates a random key using a +# function lfsr88_key and then uses this key to compute random integers +# and update the out array based on the computed indices and weight. +# +# The function is likely designed to be parallelized. + + +@ti.kernel +def _event_mv_prob_homo_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_homo_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. 
+ key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col] != 0.: + r += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + +# ------------- +# GPU function +# ------------- +# Contrary to the CPU functions, for each column, +# this function uses 32 threads (one warp) to make +# the just-in-time random generation parallelized. + + +@ti.kernel +def _event_mv_prob_homo_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) # rows per lane: (num_row + 1) // 32, at least 1 + + for i in range(num_col * 32): + i_col = i >> 5 # 32 consecutive values of i share one column + if events[i_col] != 0.: + index = i & 31 # lane number (0..31) within that group of 32 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) # each (column, lane) pair gets its own random stream + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_homo_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) # NOTE(review): derived from num_row but used below to stride columns (bounded by num_col) - confirm this is intended for non-square shapes + + for i in range(num_row * 32): + i_row = i >> 5 # 32 consecutive values of i share one output row + index = i & 31 # lane number (0..31) within that group of 32 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0.
+ key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison with if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _event_mv_prob_homo_jvp_events( + evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(evt_dot, weight, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_mv_prob_homo_jvp_weight( + w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(events, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + +def raw_event_mv_prob_homo( + events: jax.Array, + weight: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_outdim_parallel_bool_p + else: + prim = _event_mv_prob_homo_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_bool_p + else: + prim = _event_mv_prob_homo_p + + return prim(events, + weight, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], + shape=mat_shape, + 
transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_homo_jvp_events, + _event_mv_prob_homo_jvp_weight, + None, + None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + +# outdim_parallel = True, events.dtype = jnp.bool_ +_event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu +) + +# outdim_parallel = False, events.dtype = jnp.bool_ +_event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_bool_cpu, + gpu_kernel=_event_mv_prob_homo_bool_gpu +) + +# outdim_parallel = True, events.dtype != jnp.bool_ +_event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu +) + +# outdim_parallel = False, events.dtype != jnp.bool_ +_event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_cpu, + gpu_kernel=_event_mv_prob_homo_gpu +) + + +@ti.kernel +def _event_mv_prob_uniform_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def 
_event_mv_prob_uniform_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + if events[i_col]: + r += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +@ti.kernel +def _event_mv_prob_uniform_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_uniform_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = 
ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + r += row_v * events[i_col] # TODO: speed comparison without if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +@ti.kernel +def _event_mv_prob_uniform_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_uniform_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. 
+ key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + if events[i_col] != 0.: + r += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + +@ti.kernel +def _event_mv_prob_uniform_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_uniform_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. 
+ key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + r += row_v * events[i_col] # TODO: speed comparison with if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _event_mv_prob_uniform_jvp_events( + evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_mv_prob_uniform_jvp_w_low( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_mv_prob_uniform_jvp_w_high( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def raw_event_mv_prob_uniform( + events: jax.Array, + w_low: jax.Array, # vector with size 1 + w_high: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_outdim_parallel_bool_p + else: + prim = _event_mv_prob_uniform_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = 
_event_mv_prob_uniform_bool_p + else: + prim = _event_mv_prob_uniform_p + + return prim(events, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_uniform_jvp_events, + _event_mv_prob_uniform_jvp_w_low, + _event_mv_prob_uniform_jvp_w_high, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim + + +# outdim_parallel = True, events.dtype = jnp.bool_ +_event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu +) + +# outdim_parallel = False, events.dtype = jnp.bool_ +_event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_bool_gpu +) + +# outdim_parallel = True, events.dtype != jnp.bool_ +_event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu +) + +# outdim_parallel = False, events.dtype != jnp.bool_ +_event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_cpu, + gpu_kernel=_event_mv_prob_uniform_gpu +) + + +@ti.kernel +def _event_mv_prob_normal_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if 
events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_normal_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + if events[i_col]: + r += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +@ti.kernel +def _event_mv_prob_normal_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_normal_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), 
+ w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += row_v * events[i_col] # TODO: speed comparison without if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +@ti.kernel +def _event_mv_prob_normal_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_normal_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. 
+ key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + if events[i_col] != 0.: + r += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +@ti.kernel +def _event_mv_prob_normal_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_normal_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. 
+ key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += row_v * events[i_col] # TODO: speed comparison with if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _event_mv_prob_normal_jvp_events( + evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_mv_prob_normal_jvp_w_mu( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_mv_prob_normal_jvp_w_sigma( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def raw_event_mv_prob_normal( + events: jax.Array, + w_mu: jax.Array, # vector with size 1 + w_sigma: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_outdim_parallel_bool_p + else: + prim = _event_mv_prob_normal_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_bool_p + else: 
+ prim = _event_mv_prob_normal_p + + return prim(events, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_normal_jvp_events, + _event_mv_prob_normal_jvp_w_mu, + _event_mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + +# outdim_parallel = True, events.dtype = jnp.bool_ +_event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu +) + +# outdim_parallel = False, events.dtype = jnp.bool_ +_event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_bool_cpu, + gpu_kernel=_event_mv_prob_normal_bool_gpu +) + +# outdim_parallel = True, events.dtype != jnp.bool_ +_event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu +) + +# outdim_parallel = False, events.dtype != jnp.bool_ +_event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_cpu, + gpu_kernel=_event_mv_prob_normal_gpu +) diff --git a/brainpy/_src/math/jitconn/_event_matvec_taichi.py b/brainpy/_src/math/jitconn/_event_matvec_taichi.py deleted file mode 100644 index 8346607aa..000000000 --- a/brainpy/_src/math/jitconn/_event_matvec_taichi.py +++ /dev/null @@ -1,1277 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Tuple, Optional - -import jax -import numpy as np -from jax import numpy as jnp - -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax 
-from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_uniform, lfsr88_normal, lfsr88_random_integers) -from ._matvec_taichi import (_general_checking, raw_mv_prob_homo, raw_mv_prob_uniform, raw_mv_prob_normal, - _mv_prob_homo_transpose, _mv_prob_uniform_transpose, _mv_prob_normal_transpose, - _reverse) - -ti = import_taichi() - -__all__ = [ - 'event_mv_prob_homo_taichi', - 'event_mv_prob_uniform_taichi', - 'event_mv_prob_normal_taichi', -] - - -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. 
- key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - if events[i_col]: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. 
- key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - if events[i_col] != 0.: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_homo_jvp_events( - evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(evt_dot, weight, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_homo_jvp_weight( - w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(events, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - -def raw_event_mv_prob_homo( - events: jax.Array, - weight: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_outdim_parallel_bool_p - else: - prim = _event_mv_prob_homo_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_bool_p - else: - prim = _event_mv_prob_homo_p - - return prim(events, - weight, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], - shape=mat_shape, - 
transpose=transpose, - outdim_parallel=outdim_parallel) - - -def event_mv_prob_homo_taichi( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. 
- """ - events = as_jax(events) - if isinstance(weight, float): weight = as_jax(weight) - weight = jnp.atleast_1d(as_jax(weight)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_homo(events, weight, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_homo_jvp_events, - _event_mv_prob_homo_jvp_weight, - None, - None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_bool_cpu, - gpu_kernel=_event_mv_prob_homo_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_cpu, - gpu_kernel=_event_mv_prob_homo_gpu -) - - -@ti.kernel -def _event_mv_prob_uniform_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - 
num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_uniform_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, 
clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = 
events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_uniform_jvp_events( - evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_uniform_jvp_w_low( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_uniform_jvp_w_high( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def raw_event_mv_prob_uniform( - events: jax.Array, - w_low: jax.Array, # vector with size 1 - w_high: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_outdim_parallel_bool_p - else: - prim = _event_mv_prob_uniform_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = 
_event_mv_prob_uniform_bool_p - else: - prim = _event_mv_prob_uniform_p - - return prim(events, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def event_mv_prob_uniform_taichi( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. 
- It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - events = as_jax(events) - if isinstance(w_low, float): w_low = as_jax(w_low) - if isinstance(w_high, float): w_high = as_jax(w_high) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_uniform(events, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_uniform_jvp_events, - _event_mv_prob_uniform_jvp_w_low, - _event_mv_prob_uniform_jvp_w_high, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ 
-_event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_cpu, - gpu_kernel=_event_mv_prob_uniform_gpu -) - - -@ti.kernel -def _event_mv_prob_normal_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. 
- key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_normal_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. 
- key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_normal_jvp_events( - evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_mu( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_sigma( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def raw_event_mv_prob_normal( - events: jax.Array, - w_mu: jax.Array, # vector with size 1 - w_sigma: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_outdim_parallel_bool_p - else: - prim = _event_mv_prob_normal_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_bool_p - else: 
- prim = _event_mv_prob_normal_p - - return prim(events, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def event_mv_prob_normal_taichi( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. 
- It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - events = as_jax(events) - if isinstance(w_mu, float): w_mu = as_jax(w_mu) - if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_normal(events, w_mu, w_sigma, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_normal_jvp_events, - _event_mv_prob_normal_jvp_w_mu, - _event_mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_bool_cpu, - gpu_kernel=_event_mv_prob_normal_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ 
-_event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_cpu, - gpu_kernel=_event_mv_prob_normal_gpu -) diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index cad95924d..9076ce311 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -4,6 +4,7 @@ from functools import partial from typing import Tuple, Optional, Union +import brainpy.math as bm import jax import numpy as np from jax import numpy as jnp, dtypes @@ -11,12 +12,15 @@ from jax.interpreters import xla, ad from jax.lib import xla_client -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype -from brainpy._src.math.op_register import register_general_batching +from brainpy._src.math.op_register import register_general_batching, XLACustomOp +from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) from brainpy.errors import GPUOperatorNotFound +ti = import_taichi() + __all__ = [ 'mv_prob_homo', 'mv_prob_uniform', @@ -33,6 +37,241 @@ def mv_prob_homo( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, + method: str = None, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. 
+ + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + weight: float + The value of the random matrix. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. 
+ """ + if method is None: + if bm.get_platform() == 'cpu': + method = 'taichi' + elif bm.get_platform() == 'gpu': + if outdim_parallel: + method = 'brainpylib' + else: + method = 'taichi' + + if method == 'taichi': + return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + elif method == 'brainpylib': + return mv_prob_homo_brainpylib(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + else: + raise ValueError(f'Unknown method {method}.') + + + +def mv_prob_uniform( + vector: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + w_low: float + Lower boundary of the output interval. + w_high: float + Upper boundary of the output interval. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. 
+ seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + if method is None: + if bm.get_platform() == 'cpu': + method = 'taichi' + elif bm.get_platform() == 'gpu': + if outdim_parallel: + method = 'brainpylib' + else: + method = 'taichi' + + if method == 'taichi': + return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + elif method == 'brainpylib': + return mv_prob_uniform_brainpylib(vector, w_low, w_high, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + else: + raise ValueError(f'Unknown method {method}.') + + +def mv_prob_normal( + vector: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a normal distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). 
+ + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + w_mu: float + Mean (centre) of the distribution. + w_sigma: float + Standard deviation (spread or “width”) of the distribution. Must be non-negative. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + if method is None: + if bm.get_platform() == 'cpu': + method = 'taichi' + elif bm.get_platform() == 'gpu': + if outdim_parallel: + method = 'brainpylib' + else: + method = 'taichi' + + if method == 'taichi': + return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + elif method == 'brainpylib': + return mv_prob_uniform_brainpylib(vector, w_mu, w_sigma, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + else: + raise ValueError(f'Unknown method {method}.') + + +### BRAINYPLIB ### + +def mv_prob_homo_brainpylib( + vector: Union[Array, jax.Array], + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, ) -> jax.Array: r"""Perform the :math:`y=M@v` operation, where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. 
@@ -100,7 +339,7 @@ def mv_prob_homo( )[0] -def mv_prob_uniform( +def mv_prob_uniform_brainpylib( vector: jax.Array, w_low: float, w_high: float, @@ -180,7 +419,7 @@ def mv_prob_uniform( outdim_parallel=outdim_parallel)[0] -def mv_prob_normal( +def mv_prob_normal_brainpylib( vector: jax.Array, w_mu: float, w_sigma: float, @@ -260,6 +499,7 @@ def mv_prob_normal( outdim_parallel=outdim_parallel)[0] + def _matvec_prob_homo_abstract( vector, weight, clen, seed, *, shape, transpose, outdim_parallel ): @@ -817,3 +1057,889 @@ def _matvec_prob_normal_transpose( register_general_batching(mv_prob_normal_p) ad.primitive_jvps[mv_prob_normal_p] = _matvec_prob_normal_jvp ad.primitive_transposes[mv_prob_normal_p] = _matvec_prob_normal_transpose + + +### TAICHI ### +def mv_prob_homo_taichi( + vector: Union[Array, jax.Array], + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. 
+ + Generally, the :math:`M` in ``f(outdim_parallel=True, transpose=False)`` is the same of + the :math:`M^T` used in ``f(outdim_parallel=False, transpose=True)``. + + Similarly, the :math:`M^T` in ``f(outdim_parallel=True, transpose=True)`` is the same + of the :math:`M` used in ``f(outdim_parallel=False, transpose=False)``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + weight: float + The value of the random matrix. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + vector = as_jax(vector) + if isinstance(weight, float): + weight = as_jax(weight, dtype=vector.dtype) + weight = jnp.atleast_1d(as_jax(weight)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.asarray(seed, dtype=jnp.uint32) + seed = jnp.atleast_1d(seed) + return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + +def mv_prob_uniform_taichi( + vector: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. 
transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + w_low: float + Lower boundary of the output interval. + w_high: float + Upper boundary of the output interval. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. 
+ """ + vector = as_jax(vector) + if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) + if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + +def mv_prob_normal_taichi( + vector: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a normal distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + w_mu: float + Mean (centre) of the distribution. 
+ w_sigma: float + Standard deviation (spread or “width”) of the distribution. Must be non-negative. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + vector = as_jax(vector) + if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) + if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + +def _reverse(shape): + return shape[::-1] + + +@ti.kernel +def _mv_prob_homo_cpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + v = vector[i_col] * weight0 + while i_row < num_row: + out[i_row] += v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_homo_outdim_parallel_cpu( + vector: 
ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r * weight0 + + +@ti.kernel +def _mv_prob_homo_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + out[i_row] += weight0 * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_homo_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. 
+ key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += weight0 * r # TODO: warp-level reduction + + +def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_homo_transpose( + ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), weight, clen, seed + else: + dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, weight, clen, seed + elif ad.is_undefined_primal(weight): + if type(ct) is ad.Zero: + return vector, ad.Zero(weight), clen, seed + else: + row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + dw = jnp.sum(row * vector, keepdims=True) + return vector, dw, clen, seed + else: + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' 
+ + +def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + if vector.ndim != 1: + raise ValueError('vector should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('conn_prob must be a 1D scalar.') + + assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + + for weight in weights: + if weight.ndim != 1: + raise ValueError('weight must be a 1D scalar.') + assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' + + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be boolean value.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be boolean value.') + + if transpose: + out_shape = (shape[1],) + if vector.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') + shape = _reverse(shape) + else: + if vector.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') + out_shape = (shape[0],) + + return shape, out_shape + + +def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + +def raw_mv_prob_homo( + vector: jax.Array, + weight: jax.Array, # vector with size 1 + clen: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) + + if 
outdim_parallel: + prim = _mv_prob_homo_outdim_parallel_p + else: + prim = _mv_prob_homo_p + + return prim(vector, + weight, + clen, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + +# outdim_parallel = True +_mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) + +# outdim_parallel = False +_mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, + gpu_kernel=_mv_prob_homo_gpu) + + +@ti.kernel +def _mv_prob_uniform_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_uniform_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in 
range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +@ti.kernel +def _mv_prob_uniform_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_uniform_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. 
+ key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_uniform_transpose( + ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_low, w_high, clen, seed + else: + dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_low, w_high, clen, seed + else: + assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' + assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' 
+ assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + +def raw_mv_prob_uniform( + vector: jax.Array, + w_low: jax.Array, + w_high: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + prim = _mv_prob_uniform_outdim_parallel_p + else: + prim = _mv_prob_uniform_p + + return prim(vector, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_uniform_jvp_vector, + _mv_prob_uniform_jvp_wlow, + _mv_prob_uniform_jvp_whigh, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim + + +# outdim_parallel = True +_mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu +) + +# outdim_parallel = False +_mv_prob_uniform_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_cpu, + gpu_kernel=_mv_prob_uniform_gpu +) + + +@ti.kernel +def _mv_prob_normal_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = 
lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_normal_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +@ti.kernel +def _mv_prob_normal_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_normal_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = 
out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_normal_transpose( + ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_mu, w_sigma, clen, seed + else: + dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_mu, 
w_sigma, clen, seed + else: + assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' + assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + +def raw_mv_prob_normal( + vector: jax.Array, + w_mu: jax.Array, + w_sigma: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + prim = _mv_prob_normal_outdim_parallel_p + else: + prim = _mv_prob_normal_p + + return prim(vector, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_normal_jvp_vector, + _mv_prob_normal_jvp_w_mu, + _mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + +# outdim_parallel = True +_mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_mv_prob_normal_outdim_parallel_gpu +) + +# outdim_parallel = False +_mv_prob_normal_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_cpu, + gpu_kernel=_mv_prob_normal_gpu +) diff --git a/brainpy/_src/math/jitconn/_matvec_taichi.py b/brainpy/_src/math/jitconn/_matvec_taichi.py deleted file mode 100644 index beaf2c383..000000000 --- a/brainpy/_src/math/jitconn/_matvec_taichi.py +++ /dev/null @@ -1,911 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Tuple, 
Optional, Union - -import jax -import numpy as np -from jax import numpy as jnp -from jax.interpreters import ad - -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array, _get_dtype -from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) - -ti = import_taichi() - -__all__ = [ - 'mv_prob_homo_taichi', - 'mv_prob_uniform_taichi', - 'mv_prob_normal_taichi', -] - - -def _reverse(shape): - return shape[::-1] - - -@ti.kernel -def _mv_prob_homo_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - v = vector[i_col] * weight0 - while i_row < num_row: - out[i_row] += v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. 
- key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r * weight0 - - -@ti.kernel -def _mv_prob_homo_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += weight0 * r # TODO: warp-level reduction - - -def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_homo_transpose( - ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), weight, clen, seed - else: - dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, weight, clen, seed - elif ad.is_undefined_primal(weight): - if type(ct) is ad.Zero: - return vector, ad.Zero(weight), clen, seed - else: - row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - dw = jnp.sum(row * vector, keepdims=True) - return vector, dw, clen, seed - else: - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' 
- - -def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - - assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - - for weight in weights: - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - out_shape = (shape[1],) - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') - shape = _reverse(shape) - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out_shape = (shape[0],) - - return shape, out_shape - - -def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - -def raw_mv_prob_homo( - vector: jax.Array, - weight: jax.Array, # vector with size 1 - clen: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - - if 
outdim_parallel: - prim = _mv_prob_homo_outdim_parallel_p - else: - prim = _mv_prob_homo_p - - return prim(vector, - weight, - clen, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def mv_prob_homo_taichi( - vector: Union[Array, jax.Array], - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Generally, the :math:`M` in ``f(outdim_parallel=True, transpose=False)`` is the same of - the :math:`M^T` used in ``f(outdim_parallel=False, transpose=True)``. - - Similarly, the :math:`M^T` in ``f(outdim_parallel=True, transpose=True)`` is the same - of the :math:`M` used in ``f(outdim_parallel=False, transpose=False)``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. 
- seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - if isinstance(weight, float): - weight = as_jax(weight, dtype=vector.dtype) - weight = jnp.atleast_1d(as_jax(weight)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.asarray(seed, dtype=jnp.uint32) - seed = jnp.atleast_1d(seed) - return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) - -# outdim_parallel = False -_mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, - gpu_kernel=_mv_prob_homo_gpu) - - -@ti.kernel -def _mv_prob_uniform_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v 
= vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_uniform_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _mv_prob_uniform_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_uniform_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: 
ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_low, w_high, clen, seed - else: - dv = 
raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_low, w_high, clen, seed - else: - assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' - assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def raw_mv_prob_uniform( - vector: jax.Array, - w_low: jax.Array, - w_high: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - prim = _mv_prob_uniform_outdim_parallel_p - else: - prim = _mv_prob_uniform_p - - return prim(vector, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def mv_prob_uniform_taichi( - vector: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. 
note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. 
- """ - vector = as_jax(vector) - if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) - if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_uniform_jvp_vector, - _mv_prob_uniform_jvp_wlow, - _mv_prob_uniform_jvp_whigh, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_uniform_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_cpu, - gpu_kernel=_mv_prob_uniform_gpu -) - - -@ti.kernel -def _mv_prob_normal_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += col_v * raw_v - key, inc = 
lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_normal_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _mv_prob_normal_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_normal_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - 
clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_mu, w_sigma, clen, seed - else: - dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_mu, w_sigma, clen, seed - else: - assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot 
differentiate through w_mu.' - assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def raw_mv_prob_normal( - vector: jax.Array, - w_mu: jax.Array, - w_sigma: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - prim = _mv_prob_normal_outdim_parallel_p - else: - prim = _mv_prob_normal_p - - return prim(vector, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def mv_prob_normal_taichi( - vector: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). 
- - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) - if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_normal_jvp_vector, - _mv_prob_normal_jvp_w_mu, - _mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True 
-_mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_normal_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_cpu, - gpu_kernel=_mv_prob_normal_gpu -) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 556213e89..24f66878e 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from functools import partial import jax import jax.numpy as jnp @@ -19,6 +20,12 @@ # (1000, 10), (1000, 2)] +brainpylib_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='brainpylib') +taichi_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='taichi') +brainpylib_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='brainpylib') +taichi_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='taichi') +brainpylib_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='brainpylib') +taichi_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='taichi') class Test_event_matvec_prob_conn(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -53,7 +60,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve if not bool_event: events = events.astype(float) - r1 = bm.jitconn.event_mv_prob_homo(events, + r1 = brainpylib_mv_prob_homo(events, homo_data, conn_prob=prob, shape=shape, @@ -62,7 +69,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = bm.jitconn.event_mv_prob_homo(events, + r2 = brainpylib_mv_prob_homo(events, homo_data, conn_prob=prob, shape=shape, @@ -72,7 +79,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, 
homo_data, bool_eve r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = bm.jitconn.event_mv_prob_homo(events, + r3 = brainpylib_mv_prob_homo(events, homo_data, conn_prob=prob, shape=(shape[1], shape[0]), @@ -120,7 +127,7 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=Tru weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: bm.jitconn.event_mv_prob_homo( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, transpose=transpose, outdim_parallel=outdim_parallel ) @@ -164,7 +171,7 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: bm.jitconn.event_mv_prob_homo( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose ).sum(), @@ -231,7 +238,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, if not bool_event: events = events.astype(float) - r1 = bm.jitconn.event_mv_prob_uniform(events, + r1 = brainpylib_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -241,7 +248,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = bm.jitconn.event_mv_prob_uniform(events, + r2 = brainpylib_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -252,7 +259,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = bm.jitconn.event_mv_prob_uniform(events, + r3 = brainpylib_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -302,7 +309,7 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, events = events.astype(float) f1 = jax.vmap( - lambda e: bm.jitconn.event_mv_prob_uniform(e, 
+ lambda e: brainpylib_mv_prob_uniform(e, w_low=0., w_high=1., conn_prob=prob, @@ -352,7 +359,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, events = events.astype(float) f1 = jax.grad( - lambda e, w_high: bm.jitconn.event_mv_prob_uniform( + lambda e, w_high: brainpylib_mv_prob_uniform( e, w_low=0., w_high=w_high, @@ -413,7 +420,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, if not bool_event: events = events.astype(float) - r1 = bm.jitconn.event_mv_prob_normal(events, + r1 = brainpylib_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -423,7 +430,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = bm.jitconn.event_mv_prob_normal(events, + r2 = brainpylib_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -434,7 +441,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = bm.jitconn.event_mv_prob_normal(events, + r3 = brainpylib_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -486,7 +493,7 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, if not bool_event: events = events.astype(float) - f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal(e, + f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, w_mu=0., w_sigma=1., conn_prob=prob, @@ -536,7 +543,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x f1 = jax.jit( jax.grad( - lambda e, w_sigma: bm.jitconn.event_mv_prob_normal( + lambda e, w_sigma: brainpylib_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py b/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py index e42434e95..62b665f47 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py +++ 
b/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +from functools import partial import jax import jax.numpy as jnp @@ -10,6 +10,12 @@ shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] +brainpylib_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='brainpylib') +taichi_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='taichi') +brainpylib_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='brainpylib') +taichi_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='taichi') +brainpylib_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='brainpylib') +taichi_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='taichi') class Test_event_matvec_prob_conn(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -44,7 +50,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve if not bool_event: events = events.astype(float) - r1 = bm.jitconn.event_mv_prob_homo_taichi(events, + r1 = taichi_mv_prob_homo(events, homo_data, conn_prob=prob, shape=shape, @@ -53,7 +59,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = bm.jitconn.event_mv_prob_homo_taichi(events, + r2 = taichi_mv_prob_homo(events, homo_data, conn_prob=prob, shape=shape, @@ -63,7 +69,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = bm.jitconn.event_mv_prob_homo_taichi(events, + r3 = taichi_mv_prob_homo(events, homo_data, conn_prob=prob, shape=(shape[1], shape[0]), @@ -111,7 +117,7 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=Tru weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: 
bm.jitconn.event_mv_prob_homo_taichi( + lambda event, data: taichi_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, transpose=transpose, outdim_parallel=outdim_parallel )[0] @@ -155,7 +161,7 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: bm.jitconn.event_mv_prob_homo_taichi( + lambda event, data: taichi_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), argnums=0 @@ -221,7 +227,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, if not bool_event: events = events.astype(float) - r1 = bm.jitconn.event_mv_prob_uniform_taichi(events, + r1 = taichi_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -231,7 +237,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = bm.jitconn.event_mv_prob_uniform_taichi(events, + r2 = taichi_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -242,7 +248,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = bm.jitconn.event_mv_prob_uniform_taichi(events, + r3 = taichi_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -292,7 +298,7 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, events = events.astype(float) f1 = jax.vmap( - lambda e: bm.jitconn.event_mv_prob_uniform_taichi(e, + lambda e: taichi_mv_prob_uniform(e, w_low=0., w_high=1., conn_prob=prob, @@ -342,7 +348,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, events = events.astype(float) f1 = jax.grad( - lambda e, w_high: bm.jitconn.event_mv_prob_uniform_taichi( + lambda e, w_high: taichi_mv_prob_uniform( e, w_low=0., w_high=w_high, @@ -403,7 
+409,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, if not bool_event: events = events.astype(float) - r1 = bm.jitconn.event_mv_prob_normal_taichi(events, + r1 = taichi_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -413,7 +419,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = bm.jitconn.event_mv_prob_normal_taichi(events, + r2 = taichi_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -424,7 +430,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = bm.jitconn.event_mv_prob_normal_taichi(events, + r3 = taichi_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -476,7 +482,7 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, if not bool_event: events = events.astype(float) - f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal_taichi(e, + f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, w_mu=0., w_sigma=1., conn_prob=prob, @@ -526,7 +532,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x f1 = jax.jit( jax.grad( - lambda e, w_sigma: bm.jitconn.event_mv_prob_normal_taichi( + lambda e, w_sigma: taichi_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 91c48fc66..25656f9ab 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from functools import partial import jax import jax.numpy as jnp @@ -18,6 +19,12 @@ (1000, 10), (1000, 2)] +brainpylib_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='brainpylib') +taichi_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='taichi') +brainpylib_mv_prob_uniform = 
partial(bm.jitconn.mv_prob_uniform, method='brainpylib') +taichi_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='taichi') +brainpylib_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='brainpylib') +taichi_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='taichi') class Test_matvec_prob_conn(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -59,7 +66,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non rng = bm.random.RandomState() vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = bm.jitconn.mv_prob_homo(vector, + r1 = brainpylib_mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, @@ -67,7 +74,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non outdim_parallel=outdim_parallel, transpose=transpose) - r2 = bm.jitconn.mv_prob_homo(vector, + r2 = brainpylib_mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, @@ -76,7 +83,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non transpose=transpose) self.assertTrue(jnp.allclose(r1, r2)) - r2 = bm.jitconn.mv_prob_homo(vector, + r2 = brainpylib_mv_prob_homo(vector, homo_data, conn_prob=prob, shape=(shape[1], shape[0]), @@ -121,7 +128,7 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64 weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: bm.jitconn.mv_prob_homo( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose @@ -166,7 +173,7 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: bm.jitconn.mv_prob_homo( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, @@ -223,7 +230,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, 
w_low, w_high, s rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = bm.jitconn.mv_prob_uniform(events, + r1 = brainpylib_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -232,7 +239,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s outdim_parallel=outdim_parallel, transpose=transpose) - r2 = bm.jitconn.mv_prob_uniform(events, + r2 = brainpylib_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -245,7 +252,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s print(r1, r2) self.assertTrue(c) - r2 = bm.jitconn.mv_prob_uniform(events, + r2 = brainpylib_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -291,7 +298,7 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform(e, + f1 = jax.vmap(lambda e: brainpylib_mv_prob_uniform(e, w_low=0., w_high=1., conn_prob=prob, @@ -340,7 +347,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) f1 = jax.grad( - lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform( + lambda e, w_low, w_high: brainpylib_mv_prob_uniform( e, w_low=w_low, w_high=w_high, @@ -400,7 +407,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = bm.jitconn.mv_prob_normal(events, + r1 = brainpylib_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -409,7 +416,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se outdim_parallel=outdim_parallel, transpose=transpose) - r2 = bm.jitconn.mv_prob_normal(events, + r2 =
brainpylib_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -422,7 +429,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se print(r1, r2) self.assertTrue(c) - r2 = bm.jitconn.mv_prob_normal(events, + r2 = brainpylib_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -469,7 +476,7 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal(e, + f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, w_mu=0., w_sigma=1., conn_prob=prob, @@ -521,7 +528,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x events = events.astype(float) f1 = jax.grad( - lambda e, w_sigma: bm.jitconn.mv_prob_normal( + lambda e, w_sigma: brainpylib_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py b/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py index 380db3cf5..8f42831d5 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +from functools import partial import jax import jax.numpy as jnp @@ -10,6 +10,12 @@ shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] +brainpylib_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='brainpylib') +taichi_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='taichi') +brainpylib_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='brainpylib') +taichi_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='taichi') +brainpylib_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='brainpylib') +taichi_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='taichi') class
Test_matvec_prob_conn(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -51,7 +57,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non rng = bm.random.RandomState() vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = bm.jitconn.mv_prob_homo_taichi(vector, + r1 = taichi_mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, @@ -59,7 +65,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non outdim_parallel=outdim_parallel, transpose=transpose) - r2 = bm.jitconn.mv_prob_homo_taichi(vector, + r2 = taichi_mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, @@ -68,7 +74,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non transpose=transpose) self.assertTrue(jnp.allclose(r1, r2)) - r2 = bm.jitconn.mv_prob_homo_taichi(vector, + r2 = taichi_mv_prob_homo(vector, homo_data, conn_prob=prob, shape=(shape[1], shape[0]), @@ -111,7 +117,7 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64 weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: bm.jitconn.mv_prob_homo_taichi( + lambda event, data: taichi_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose @@ -156,7 +162,7 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: bm.jitconn.mv_prob_homo_taichi( + lambda event, data: taichi_mv_prob_homo( event, data, conn_prob=prob, shape=shape, @@ -213,7 +219,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = bm.jitconn.mv_prob_uniform_taichi(events, + r1 = taichi_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -222,7 +228,7 @@ def test_uniform(self, 
shape, transpose, outdim_parallel, prob, w_low, w_high, s outdim_parallel=outdim_parallel, transpose=transpose) - r2 = bm.jitconn.mv_prob_uniform_taichi(events, + r2 = taichi_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -235,7 +241,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s print(r1, r2) self.assertTrue(c) - r2 = bm.jitconn.mv_prob_uniform_taichi(events, + r2 = taichi_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -281,7 +287,7 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform_taichi(e, + f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e, w_low=0., w_high=1., conn_prob=prob, @@ -330,7 +336,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) f1 = jax.grad( - lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform_taichi( + lambda e, w_low, w_high: taichi_mv_prob_uniform( e, w_low=w_low, w_high=w_high, @@ -390,7 +396,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = bm.jitconn.mv_prob_normal_taichi(events, + r1 = taichi_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -399,7 +405,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se outdim_parallel=outdim_parallel, transpose=transpose) - r2 = bm.jitconn.mv_prob_normal_taichi(events, + r2 = taichi_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -412,7 +418,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se print(r1, r2) self.assertTrue(c) - r2 = bm.jitconn.mv_prob_normal_taichi(events, + r2 = 
taichi_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -459,7 +465,7 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal_taichi(e, + f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, w_mu=0., w_sigma=1., conn_prob=prob, @@ -512,7 +518,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x events = events.astype(float) f1 = jax.grad( - lambda e, w_sigma: bm.jitconn.mv_prob_normal_taichi( + lambda e, w_sigma: taichi_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index cd94d0621..d45f2c80b 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,7 +1,6 @@ from ._coo_mv import * from ._csr_mv import * -from ._csr_mv_taichi import * from ._utils import * from ._bsr_mv import * from ._bsr_mm import * diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index d874ad901..aa29ed36a 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -13,20 +13,84 @@ from jax.lib import xla_client from jaxlib import gpu_sparse -from brainpy._src.dependency_check import import_brainpylib_gpu_ops +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching) + register_general_batching, + XLACustomOp) from brainpy._src.math.sparse._utils import csr_to_coo from brainpy.errors import GPUOperatorNotFound +ti = import_taichi() + __all__ = [ 'csrmv', ] def csrmv( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], 
+ indptr: Union[jnp.ndarray, Array], + vector: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, + method: str = None, +): + """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + vector: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple of int + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + method: str + The method used to compute Matrix-Vector Multiplication. Default is ``taichi``. + The candidate methods are: + + - ``taichi``: using Taichi kernel. + - ``brainpylib``: using cuSPARSE library. + - ``cusparse``: using cuSPARSE library. + - ``scalar``: + - ``vector``: + - ``adaptive``: + + Returns + ------- + y : ndarry + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + if method is None: + method = 'taichi' + + if method == 'taichi': + return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose) + elif method == 'brainpylib': + return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method='cusparse') + else: + return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method=method) + + +### BRAINPYLIB ### + +def csrmv_brainpylib( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], indptr: Union[jnp.ndarray, Array], @@ -107,7 +171,6 @@ def csrmv( else: raise ValueError(f'Only support methods: cusparse, scalar, vector, and adaptive. 
But we got {method}.') - def _csrmv_abstract(data, indices, indptr, vector, *, shape, transpose): if data.dtype not in [jnp.float32, jnp.float64]: raise TypeError(f'Only support float32 and float64. But we got {data.dtype}.') @@ -466,3 +529,288 @@ def _csrmv_adaptive_transpose(ct, data, indices, indptr, vector, *, shape, trans partial(_csrmv_jvp_vec, _csrmv_adaptive_p), ) ad.primitive_transposes[_csrmv_adaptive_p] = _csrmv_adaptive_transpose register_general_batching(_csrmv_adaptive_p) + + +### TAICHI ### + +def csrmv_taichi( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + vector: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, +) -> jax.Array: + """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + vector: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple of int + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + + Returns + ------- + y : ndarry + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + + data = jnp.atleast_1d(as_jax(data)) + indices = as_jax(indices) + indptr = as_jax(indptr) + vector = as_jax(vector) + + if vector.dtype == jnp.bool_: + vector = as_jax(vector, dtype=data.dtype) + + if data.dtype not in [jnp.float16, jnp.float32, jnp.float64]: + raise TypeError('Only support float16, float32 or float64 type. 
' + f'But we got {data.dtype}.') + if data.dtype != vector.dtype: + raise TypeError('The types of data and vector should be the same. ' + f'But we got {data.dtype} != {vector.dtype}.') + assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 + if not jnp.issubdtype(indices.dtype, jnp.integer): + raise ValueError('indices should be a 1D vector with integer type.') + if not jnp.issubdtype(indptr.dtype, jnp.integer): + raise ValueError('indptr should be a 1D vector with integer type.') + + # if the shape of indices is (0,), then we return a zero vector + if indices.shape[0] == 0: + return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) + + return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] + + +# ------------- +# CPU operators +# ------------- + + +@ti.kernel +def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += value * vector[row_i] + + +@ti.kernel +def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += vector[row_i] * values[j] + + +@ti.kernel +def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in 
range(row_ptr.shape[0] - 1): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += vector[col_indices[j]] + out[row_i] = r * value + + +@ti.kernel +def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * vector[col_indices[j]] + out[row_i] = r + + +# ------------- +# GPU operators +# ------------- + + +@ti.kernel +def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + out[col_indices[j]] += value * vector[row_i] + j += 32 + + +@ti.kernel +def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. 
+ j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + r += vector[col_indices[j]] + j += 32 + out[row_i] += value * r + + +@ti.kernel +def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + out[col_indices[j]] += values[j] * vector[row_i] + j += 32 + + +@ti.kernel +def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + r += values[j] * vector[col_indices[j]] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + +def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): + return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) + + +def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): + return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) + + +def _sparse_csr_matvec_transpose( + ct, data, indices, indptr, vector, *, outs, transpose, shape, +): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(vector): + ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + return data, indices, indptr, (ad.Zero(vector) if 
type(ct[0]) is ad.Zero else ct_vector) + + else: + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) + else: + if data.aval.shape[0] == 1: # scalar + ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] + ct_data = jnp.inner(ct[0], ct_data) + else: + row, col = csr_to_coo(indices, indptr) + ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] + + return ct_data, indices, indptr, vector + +def raw_csrmv_taichi( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + vector: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, +): + out_shape = shape[1] if transpose else shape[0] + if transpose: + if data.shape[0] == 1: + prim = _csr_matvec_transpose_homo_p + else: + prim = _csr_matvec_transpose_heter_p + else: + if data.shape[0] == 1: + prim = _csr_matvec_homo_p + else: + prim = _csr_matvec_heter_p + + return prim(data, + indices, + indptr, + vector, + outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)], + transpose=transpose, + shape=shape) + + +def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) + prim.def_transpose_rule(_sparse_csr_matvec_transpose) + return prim + + +# transpose homo +_csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, + gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) + +# no transpose homo +_csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, + gpu_kernel=_sparse_csr_matvec_homo_gpu) + +# transpose heter +_csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, + gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) + +# no transpose heter +_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, + 
gpu_kernel=_sparse_csr_matvec_heter_gpu) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/_csr_mv_taichi.py b/brainpy/_src/math/sparse/_csr_mv_taichi.py deleted file mode 100644 index 84cae5554..000000000 --- a/brainpy/_src/math/sparse/_csr_mv_taichi.py +++ /dev/null @@ -1,304 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Union, Tuple - -import jax -from jax import numpy as jnp -from jax.interpreters import ad - -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse._utils import csr_to_coo - -ti = import_taichi() - -__all__ = [ - 'csrmv_taichi', -] - - -# ------------- -# CPU operators -# ------------- - - -@ti.kernel -def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += value * vector[row_i] - - -@ti.kernel -def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += vector[row_i] * values[j] - - -@ti.kernel -def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in 
range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += vector[col_indices[j]] - out[row_i] = r * value - - -@ti.kernel -def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - - -@ti.kernel -def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += value * vector[row_i] - j += 32 - - -@ti.kernel -def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. 
- j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += vector[col_indices[j]] - j += 32 - out[row_i] += value * r - - -@ti.kernel -def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += values[j] * vector[row_i] - j += 32 - - -@ti.kernel -def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += values[j] * vector[col_indices[j]] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) - - -def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) - - -def _sparse_csr_matvec_transpose( - ct, data, indices, indptr, vector, *, outs, transpose, shape, -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(vector): - ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(vector) if 
type(ct[0]) is ad.Zero else ct_vector) - - else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] - ct_data = jnp.inner(ct[0], ct_data) - else: - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] - - return ct_data, indices, indptr, vector - -def raw_csrmv_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -): - out_shape = shape[1] if transpose else shape[0] - if transpose: - if data.shape[0] == 1: - prim = _csr_matvec_transpose_homo_p - else: - prim = _csr_matvec_transpose_heter_p - else: - if data.shape[0] == 1: - prim = _csr_matvec_homo_p - else: - prim = _csr_matvec_heter_p - - return prim(data, - indices, - indptr, - vector, - outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)], - transpose=transpose, - shape=shape) - - -def csrmv_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -) -> jax.Array: - """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple of int - A length-2 tuple representing the matrix shape. 
- transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - - Returns - ------- - y : ndarry - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - - data = jnp.atleast_1d(as_jax(data)) - indices = as_jax(indices) - indptr = as_jax(indptr) - vector = as_jax(vector) - - if vector.dtype == jnp.bool_: - vector = as_jax(vector, dtype=data.dtype) - - if data.dtype not in [jnp.float16, jnp.float32, jnp.float64]: - raise TypeError('Only support float16, float32 or float64 type. ' - f'But we got {data.dtype}.') - if data.dtype != vector.dtype: - raise TypeError('The types of data and vector should be the same. ' - f'But we got {data.dtype} != {vector.dtype}.') - assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - - # if the shape of indices is (0,), then we return a zero vector - if indices.shape[0] == 0: - return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) - - return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] - - - -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) - prim.def_transpose_rule(_sparse_csr_matvec_transpose) - return prim - - -# transpose homo -_csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) - -# no transpose homo -_csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, - gpu_kernel=_sparse_csr_matvec_homo_gpu) - -# transpose heter -_csr_matvec_transpose_heter_p = 
_define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) - -# no transpose heter -_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py index fed665c8d..5e2a644a4 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py @@ -53,7 +53,7 @@ def compare_with_nan_tolerance(a, b, tol=1e-8): vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') - +taichi_csr_matvec = partial(bm.sparse.csrmv, method='taichi') class Test_csrmv_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -81,7 +81,7 @@ def test_homo(self, transpose, shape, homo_data): vector = bm.as_jax(vector) r1 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = bm.sparse.csrmv_taichi(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() @@ -108,7 +108,7 @@ def test_homo_vmap(self, transpose, shape, v): f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) - f2 = partial(bm.sparse.csrmv_taichi, indices=indices, indptr=indptr, vector=vector, + f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) r1 = jax.vmap(f1)(homo_data) r2 = jax.vmap(f1)(homo_data) @@ -140,13 +140,13 @@ def test_homo_grad(self, transpose, shape, homo_data): # grad 'data' r1 = jax.grad(sum_op(vector_csr_matvec))( homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(bm.sparse.csrmv_taichi))( + r2 = 
jax.grad(sum_op(taichi_csr_matvec))( homo_data, indices, indptr, vector, shape=shape, transpose=transpose) # csr_f1 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, # shape=shape, transpose=transpose).sum(), # argnums=0) - # csr_f2 = jax.grad(lambda a: bm.sparse.csrmv_taichi(a, indices, indptr, vector, + # csr_f2 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, # shape=shape, transpose=transpose)[0].sum(), # argnums=0) # r1 = csr_f1(homo_data) @@ -157,14 +157,14 @@ def test_homo_grad(self, transpose, shape, homo_data): # grad 'vector' r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=3)( + r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=(0, 3))( + r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) @@ -191,7 +191,7 @@ def test_heter(self, transpose, shape): vector = bm.as_jax(vector) r1 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r2 = bm.sparse.csrmv_taichi(heter_data, indices, indptr, vector, shape=shape) + r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape) self.assertTrue(compare_with_nan_tolerance(r1, r2)) @@ -218,7 +218,7 @@ def test_heter_vmap(self, transpose, shape): f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) - f2 = partial(bm.sparse.csrmv_taichi, indices=indices, 
indptr=indptr, vector=vector, + f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) r1 = jax.vmap(f1)(heter_data) r2 = jax.vmap(f2)(heter_data) @@ -244,20 +244,20 @@ def test_heter_grad(self, transpose, shape): # grad 'data' r1 = jax.grad(sum_op(vector_csr_matvec))( heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(bm.sparse.csrmv_taichi))( + r2 = jax.grad(sum_op(taichi_csr_matvec))( heter_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'vector' r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=3)( + r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))( heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=(0, 3))( + r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) diff --git a/brainpy/math/event.py b/brainpy/math/event.py index 2e9f38039..0a17cae7c 100644 --- a/brainpy/math/event.py +++ b/brainpy/math/event.py @@ -1,6 +1,5 @@ from brainpy._src.math.event import ( csrmv as csrmv, - csrmv_taichi as csrmv_taichi, info as info, ) diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index 0ade274e6..90a028b7e 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -6,13 +6,5 @@ mv_prob_homo as mv_prob_homo, mv_prob_uniform as mv_prob_uniform, mv_prob_normal as mv_prob_normal, - - 
event_mv_prob_homo_taichi as event_mv_prob_homo_taichi, - event_mv_prob_uniform_taichi as event_mv_prob_uniform_taichi, - event_mv_prob_normal_taichi as event_mv_prob_normal_taichi, - - mv_prob_homo_taichi as mv_prob_homo_taichi, - mv_prob_uniform_taichi as mv_prob_uniform_taichi, - mv_prob_normal_taichi as mv_prob_normal_taichi ) diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index 97c585746..1380a9e9c 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,6 +1,5 @@ from brainpy._src.math.sparse import ( csrmv, - csrmv_taichi, coomv, seg_matmul, From bcd9afbbd37112513248a4cf33452d3da416aada Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 28 Jan 2024 19:23:22 +0800 Subject: [PATCH 12/27] [math] Fix bugs --- brainpy/_src/math/event/_csr_matvec.py | 4 ++-- brainpy/_src/math/event/tests/test_event_csrmv.py | 4 ++-- brainpy/_src/math/jitconn/_event_matvec.py | 3 +++ brainpy/_src/math/jitconn/_matvec.py | 2 ++ 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 53e110dd4..80f92827f 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -590,7 +590,7 @@ def _event_csr_matvec_batching_rule(args, axes, *, shape, transpose): def _event_csr_matvec_jvp_values(values_dot, values, indices, indptr, events, *, shape, transpose): - return csrmv(values_dot, indices, indptr, events, shape=shape, transpose=transpose) + return csrmv_brainpylib(values_dot, indices, indptr, events, shape=shape, transpose=transpose) def _event_csr_matvec_jvp_events(events_dot, values, indices, indptr, events, *, shape, transpose): @@ -608,7 +608,7 @@ def _event_csr_matvec_transpose(ct, values, indices, indptr, events, *, shape, t ct_values = ad.Zero(values) else: if values.aval.shape[0] == 1: # scalar - ct_values = csrmv(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose) + ct_values = 
csrmv_brainpylib(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose) ct_values = jnp.inner(ct, ct_values) else: # heterogeneous values row, col = csr_to_coo(indices, indptr) diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 55253ab47..67b77c5a3 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -13,8 +13,8 @@ import pytest is_manual_test = False -if platform.system() == 'Windows' and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) +# if platform.system() == 'Windows' and not is_manual_test: +# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index e20cf804c..f09e069e8 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -50,6 +50,7 @@ def event_mv_prob_homo( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, + method: str = None, ) -> jax.Array: if method is None: if bm.get_platform() == 'cpu': @@ -80,6 +81,7 @@ def event_mv_prob_uniform( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, + method: str = None, ) -> jax.Array: if method is None: if bm.get_platform() == 'cpu': @@ -110,6 +112,7 @@ def event_mv_prob_normal( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, + method: str = None, ) -> jax.Array: if method is None: if bm.get_platform() == 'cpu': diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 9076ce311..d6a41c30c 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ 
b/brainpy/_src/math/jitconn/_matvec.py @@ -115,6 +115,7 @@ def mv_prob_uniform( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, + method: str = None, ) -> jax.Array: r"""Perform the :math:`y=M@v` operation, where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. @@ -193,6 +194,7 @@ def mv_prob_normal( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, + method: str = None, ) -> jax.Array: r"""Perform the :math:`y=M@v` operation, where :math:`M` is just-in-time randomly generated with a normal distribution for its value. From ee018b066ef58467006fd3f0103b0b8d8bd03ac0 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 28 Jan 2024 19:29:09 +0800 Subject: [PATCH 13/27] [dnn] Fix bugs --- brainpy/_src/dnn/linear.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index c3f287f64..f85acbc7b 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -4,14 +4,11 @@ import numbers from typing import Dict, Optional, Union, Callable -from brainpy._src.dependency_check import import_taichi import jax import jax.numpy as jnp import numba import numpy as np -ti = import_taichi() - from brainpy import math as bm from brainpy._src import connect, initialize as init from brainpy._src.context import share @@ -579,7 +576,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, + return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) elif x.ndim > 1: @@ -591,7 +588,7 @@ def update(self, x): raise ValueError def _batch_csrmv(self, x): - return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, + return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, 
shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) @@ -640,7 +637,7 @@ def update(self, x): raise ValueError def _batch_csrmv(self, x): - return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x, + return bm.event.csrmv(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) @@ -856,7 +853,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -871,7 +868,7 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -936,7 +933,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -951,7 +948,7 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1016,7 +1013,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), 
transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1031,7 +1028,7 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1095,7 +1092,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1110,7 +1107,7 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1175,7 +1172,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1190,7 +1187,7 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1255,7 +1252,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.jitconn.event_mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + return 
bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) @@ -1270,7 +1267,7 @@ def update(self, x): raise ValueError def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) From 85afa71810ad8b256bcbc73a1b552cb045b99867 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 28 Jan 2024 20:31:48 +0800 Subject: [PATCH 14/27] [math] Fix bugs --- brainpy/_src/dnn/linear.py | 2 +- brainpy/_src/math/event/_csr_matvec.py | 71 ++++++++++---------- brainpy/_src/math/sparse/tests/test_csrmv.py | 4 +- 3 files changed, 38 insertions(+), 39 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index f85acbc7b..62f98f6c6 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -625,7 +625,7 @@ def __init__( def update(self, x): if x.ndim == 1: - return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x, + return bm.event.csrmv(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) elif x.ndim > 1: diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 80f92827f..1a1851adb 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -28,7 +28,7 @@ from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching, XLACustomOp) -from brainpy._src.math.sparse._csr_mv import csrmv as normal_csrmv +from brainpy._src.math.sparse._csr_mv import csrmv_brainpylib as normal_csrmv from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi from 
brainpy._src.math.sparse._utils import csr_to_coo from brainpy._src.dependency_check import (import_brainpylib_gpu_ops) @@ -589,15 +589,15 @@ def _event_csr_matvec_batching_rule(args, axes, *, shape, transpose): return r, 0 -def _event_csr_matvec_jvp_values(values_dot, values, indices, indptr, events, *, shape, transpose): - return csrmv_brainpylib(values_dot, indices, indptr, events, shape=shape, transpose=transpose) +def _event_csr_matvec_jvp_values_brainpylib(values_dot, values, indices, indptr, events, *, shape, transpose): + return normal_csrmv(values_dot, indices, indptr, events, shape=shape, transpose=transpose) -def _event_csr_matvec_jvp_events(events_dot, values, indices, indptr, events, *, shape, transpose): +def _event_csr_matvec_jvp_events_brainpylib(events_dot, values, indices, indptr, events, *, shape, transpose): return normal_csrmv(values, indices, indptr, events_dot, shape=shape, transpose=transpose) -def _event_csr_matvec_transpose(ct, values, indices, indptr, events, *, shape, transpose): +def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events, *, shape, transpose): if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): raise ValueError("Cannot transpose with respect to sparse indices.") if ad.is_undefined_primal(events): @@ -621,8 +621,8 @@ def _event_csr_matvec_transpose(ct, values, indices, indptr, events, *, shape, t event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p)) xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation -ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values, None, None, _event_csr_matvec_jvp_events) -ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose +ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None, _event_csr_matvec_jvp_events_brainpylib) 
+ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib register_general_batching(event_csr_matvec_p) # batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule @@ -1041,11 +1041,37 @@ def raw_csrmv_taichi( transpose=transpose, shape=shape) +def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) + + +def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + +def _event_csr_matvec_transpose_taichi( + ct, values, indices, indptr, events, *, outs, transpose, shape +): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(events): + ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] + return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) + else: + if type(ct[0]) is ad.Zero: + ct_values = ad.Zero(values) + else: + if values.aval.shape[0] == 1: # scalar + ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] + ct_values = jnp.inner(ct[0], ct_values) + else: # heterogeneous values + row, col = csr_to_coo(indices, indptr) + ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] + return ct_values, indices, indptr, events def _define_op(cpu_kernel, gpu_kernel): prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matvec_jvp_values, None, None, _event_csr_matvec_jvp_events) - prim.def_transpose_rule(_event_csr_matvec_transpose) + prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, 
_event_csr_matvec_jvp_events_taichi) + prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) return prim @@ -1079,30 +1105,3 @@ def _define_op(cpu_kernel, gpu_kernel): -def _event_csr_matvec_jvp_values(val_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) - - -def _event_csr_matvec_jvp_events(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) - - -def _event_csr_matvec_transpose( - ct, values, indices, indptr, events, *, outs, transpose, shape -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] - return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) - else: - if type(ct[0]) is ad.Zero: - ct_values = ad.Zero(values) - else: - if values.aval.shape[0] == 1: # scalar - ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] - ct_values = jnp.inner(ct[0], ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] - return ct_values, indices, indptr, events \ No newline at end of file diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 16bf43a48..b902cd8a0 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -10,8 +10,8 @@ import brainpy.math as bm is_manual_test = False -if platform.system() == 'Windows' and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', 
allow_module_level=True) +# if platform.system() == 'Windows' and not is_manual_test: +# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse') scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar') From ff846aee65e8dbee2ea276c60583375ecb309822 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 28 Jan 2024 22:25:36 +0800 Subject: [PATCH 15/27] [math] Fix jitconn matvec bugs --- brainpy/_src/math/jitconn/_event_matvec.py | 4 ++-- brainpy/_src/math/jitconn/_matvec.py | 4 ++-- brainpy/_src/math/jitconn/tests/test_matvec.py | 12 +++++++----- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index f09e069e8..11f1689ae 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -95,7 +95,7 @@ def event_mv_prob_uniform( if method == 'taichi': return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) elif method == 'brainpylib': - return event_mv_prob_uniform_brainpylib(events, w_low, w_high, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return event_mv_prob_uniform_brainpylib(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) else: raise ValueError(f'Unknown method {method}.') @@ -126,7 +126,7 @@ def event_mv_prob_normal( if method == 'taichi': return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) elif method == 'brainpylib': - return event_mv_prob_uniform_brainpylib(events, w_mu, w_sigma, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return event_mv_prob_uniform_brainpylib(events, w_mu, w_sigma, conn_prob, seed, 
shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) else: raise ValueError(f'Unknown method {method}.') diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index d6a41c30c..e2dd0e72d 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -179,7 +179,7 @@ def mv_prob_uniform( if method == 'taichi': return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) elif method == 'brainpylib': - return mv_prob_uniform_brainpylib(vector, w_low, w_high, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return mv_prob_uniform_brainpylib(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) else: raise ValueError(f'Unknown method {method}.') @@ -258,7 +258,7 @@ def mv_prob_normal( if method == 'taichi': return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) elif method == 'brainpylib': - return mv_prob_uniform_brainpylib(vector, w_mu, w_sigma, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return mv_prob_uniform_brainpylib(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) else: raise ValueError(f'Unknown method {method}.') diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 25656f9ab..5176a13a8 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -407,7 +407,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = brainpylib_mv_prob_uniform(events, + r1 = brainpylib_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, 
conn_prob=prob, @@ -416,7 +416,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se outdim_parallel=outdim_parallel, transpose=transpose) - r2 = brainpylib_mv_prob_uniform(events, + r2 = brainpylib_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -429,7 +429,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se print(r1, r2) self.assertTrue(c) - r2 = brainpylib_mv_prob_uniform(events, + r2 = brainpylib_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -476,7 +476,7 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: brainpylib_mv_prob_uniform(e, + f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, w_mu=0., w_sigma=1., conn_prob=prob, @@ -528,7 +528,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x events = events.astype(float) f1 = jax.grad( - lambda e, w_sigma: brainpylib_mv_prob_uniform( + lambda e, w_sigma: brainpylib_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, @@ -541,6 +541,8 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x ) r1 = f1(events, 1.) r2 = f1(events, 2.) 
+ print('r1:', r1) + print('r2:', r2) self.assertTrue(bm.allclose(r1 * 2., r2)) if x64: From 97f7e7ae609d0ff373fcef41e86bbe057d0a7376 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 28 Jan 2024 23:09:12 +0800 Subject: [PATCH 16/27] Update linear.py --- brainpy/_src/dnn/linear.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 62f98f6c6..b635d21f1 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -570,15 +570,17 @@ def __init__( sharding: Optional[Sharding] = None, mode: Optional[bm.Mode] = None, name: Optional[str] = None, + method: str = None, transpose: bool = True, ): super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) + self.method = method def update(self, x): if x.ndim == 1: return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) + method=self.method, transpose=self.transpose) elif x.ndim > 1: shapes = x.shape[:-1] x = bm.flatten(x, end_dim=-2) @@ -590,7 +592,7 @@ def update(self, x): def _batch_csrmv(self, x): return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) + method=self.method, transpose=self.transpose) class EventCSRLinear(_CSRLayer): r"""Synaptic matrix multiplication with event CSR sparse computation. 
From 8a517f0e43b951621fe28242b6d0f2cdd305281d Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 29 Jan 2024 13:06:20 +0800 Subject: [PATCH 17/27] [math] Update operators --- brainpy/_src/math/event/_csr_matvec.py | 23 +--------- brainpy/_src/math/jitconn/_event_matvec.py | 51 ++-------------------- brainpy/_src/math/jitconn/_matvec.py | 51 ++-------------------- brainpy/_src/math/sparse/_csr_mv.py | 9 +--- 4 files changed, 8 insertions(+), 126 deletions(-) diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 1a1851adb..c9d3d80ba 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -49,7 +49,6 @@ def csrmv( *, shape: Tuple[int, int], transpose: bool = False, - method: str = None ) -> jax.Array: """Product of a sparse CSR matrix and a dense event vector. @@ -74,14 +73,6 @@ def csrmv( before computing. If ``transpose=True``, the operator will compute based on the event-driven property of the ``events`` vector. - method: str - The method used to compute Matrix-Vector Multiplication. For cpu platform, - the default method is ``taichi``. For gpu platform, the default method is - ``brainpylib``. - The candidate methods are: - - - ``taichi``: using Taichi kernel. - - ``brainpylib``: using brainpylib operators. Returns ------- @@ -89,19 +80,7 @@ def csrmv( The array of shape ``(shape[1] if transpose else shape[0],)`` representing the matrix vector product. 
""" - - if method is None: - if bm.get_platform() == 'cpu': - method = 'taichi' - elif bm.get_platform() == 'gpu': - method = 'brainpylib' - - if method == 'taichi': - return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) - elif method == 'brainpylib': - return csrmv_brainpylib(data, indices, indptr, events, shape=shape, transpose=transpose) - else: - raise ValueError(f'Unknown method {method}.') + return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) ### BRAINPYLIB ### diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 11f1689ae..21b082fc5 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -50,23 +50,8 @@ def event_mv_prob_homo( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, - method: str = None, ) -> jax.Array: - if method is None: - if bm.get_platform() == 'cpu': - method = 'taichi' - elif bm.get_platform() == 'gpu': - if outdim_parallel: - method = 'brainpylib' - else: - method = 'taichi' - - if method == 'taichi': - return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - elif method == 'brainpylib': - return event_mv_prob_homo_brainpylib(events, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - else: - raise ValueError(f'Unknown method {method}.') + return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ @@ -81,23 +66,8 @@ def event_mv_prob_uniform( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, - method: str = None, ) -> jax.Array: - if method is None: - if bm.get_platform() == 'cpu': - method = 'taichi' - elif bm.get_platform() == 'gpu': - if outdim_parallel: - method = 
'brainpylib' - else: - method = 'taichi' - - if method == 'taichi': - return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - elif method == 'brainpylib': - return event_mv_prob_uniform_brainpylib(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - else: - raise ValueError(f'Unknown method {method}.') + return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ @@ -112,23 +82,8 @@ def event_mv_prob_normal( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, - method: str = None, ) -> jax.Array: - if method is None: - if bm.get_platform() == 'cpu': - method = 'taichi' - elif bm.get_platform() == 'gpu': - if outdim_parallel: - method = 'brainpylib' - else: - method = 'taichi' - - if method == 'taichi': - return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - elif method == 'brainpylib': - return event_mv_prob_uniform_brainpylib(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - else: - raise ValueError(f'Unknown method {method}.') + return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) ### BRAINPYLIB ### diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index e2dd0e72d..f20c12f02 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -37,7 +37,6 @@ def mv_prob_homo( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, - method: str = None, ) -> jax.Array: r"""Perform the :math:`y=M@v` operation, where :math:`M` is 
just-in-time randomly generated with a scalar `weight` at each position. @@ -87,21 +86,7 @@ def mv_prob_homo( out: Array, ndarray The output of :math:`y = M @ v`. """ - if method is None: - if bm.get_platform() == 'cpu': - method = 'taichi' - elif bm.get_platform() == 'gpu': - if outdim_parallel: - method = 'brainpylib' - else: - method = 'taichi' - - if method == 'taichi': - return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - elif method == 'brainpylib': - return mv_prob_homo_brainpylib(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - else: - raise ValueError(f'Unknown method {method}.') + return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) @@ -115,7 +100,6 @@ def mv_prob_uniform( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, - method: str = None, ) -> jax.Array: r"""Perform the :math:`y=M@v` operation, where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. @@ -167,21 +151,7 @@ def mv_prob_uniform( out: Array, ndarray The output of :math:`y = M @ v`. 
""" - if method is None: - if bm.get_platform() == 'cpu': - method = 'taichi' - elif bm.get_platform() == 'gpu': - if outdim_parallel: - method = 'brainpylib' - else: - method = 'taichi' - - if method == 'taichi': - return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - elif method == 'brainpylib': - return mv_prob_uniform_brainpylib(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - else: - raise ValueError(f'Unknown method {method}.') + return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) def mv_prob_normal( @@ -194,7 +164,6 @@ def mv_prob_normal( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, - method: str = None, ) -> jax.Array: r"""Perform the :math:`y=M@v` operation, where :math:`M` is just-in-time randomly generated with a normal distribution for its value. @@ -246,21 +215,7 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. 
""" - if method is None: - if bm.get_platform() == 'cpu': - method = 'taichi' - elif bm.get_platform() == 'gpu': - if outdim_parallel: - method = 'brainpylib' - else: - method = 'taichi' - - if method == 'taichi': - return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - elif method == 'brainpylib': - return mv_prob_uniform_brainpylib(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - else: - raise ValueError(f'Unknown method {method}.') + return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) ### BRAINYPLIB ### diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index aa29ed36a..acd68b999 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -64,8 +64,7 @@ def csrmv( The method used to compute Matrix-Vector Multiplication. Default is ``taichi``. The candidate methods are: - - ``taichi``: using Taichi kernel. - - ``brainpylib``: using cuSPARSE library. + - ``None``: default using Taichi kernel. - ``cusparse``: using cuSPARSE library. - ``scalar``: - ``vector``: @@ -78,16 +77,10 @@ def csrmv( the matrix vector product. 
""" if method is None: - method = 'taichi' - - if method == 'taichi': return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose) - elif method == 'brainpylib': - return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method='cusparse') else: return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method=method) - ### BRAINPYLIB ### def csrmv_brainpylib( From 046dbea0bd24034b041855d9207a8d3d39233f2d Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 29 Jan 2024 13:16:23 +0800 Subject: [PATCH 18/27] [math] Update pytests --- .../_src/math/event/tests/test_event_csrmv.py | 261 +++++-------- .../math/event/tests/test_event_csrmv_gpu.py | 15 - .../math/event/tests/test_event_csrmv_old.py | 324 ++++++++++++++++ .../event/tests/test_event_csrmv_taichi.py | 240 ------------ .../math/jitconn/tests/test_event_matvec.py | 197 +++++----- .../jitconn/tests/test_event_matvec_gpu.py | 14 - ...vec_taichi.py => test_event_matvec_old.py} | 197 +++++----- .../_src/math/jitconn/tests/test_matvec.py | 228 ++++++------ .../math/jitconn/tests/test_matvec_gpu.py | 14 - ...st_matvec_taichi.py => test_matvec_old.py} | 213 ++++++----- brainpy/_src/math/sparse/tests/test_csrmv.py | 337 +++++++---------- .../_src/math/sparse/tests/test_csrmv_gpu.py | 21 -- .../_src/math/sparse/tests/test_csrmv_old.py | 352 ++++++++++++++++++ .../math/sparse/tests/test_csrmv_taichi.py | 265 ------------- 14 files changed, 1309 insertions(+), 1369 deletions(-) delete mode 100644 brainpy/_src/math/event/tests/test_event_csrmv_gpu.py create mode 100644 brainpy/_src/math/event/tests/test_event_csrmv_old.py delete mode 100644 brainpy/_src/math/event/tests/test_event_csrmv_taichi.py delete mode 100644 brainpy/_src/math/jitconn/tests/test_event_matvec_gpu.py rename brainpy/_src/math/jitconn/tests/{test_event_matvec_taichi.py => test_event_matvec_old.py} (73%) delete mode 100644 
brainpy/_src/math/jitconn/tests/test_matvec_gpu.py rename brainpy/_src/math/jitconn/tests/{test_matvec_taichi.py => test_matvec_old.py} (70%) delete mode 100644 brainpy/_src/math/sparse/tests/test_csrmv_gpu.py create mode 100644 brainpy/_src/math/sparse/tests/test_csrmv_old.py delete mode 100644 brainpy/_src/math/sparse/tests/test_csrmv_taichi.py diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 67b77c5a3..6f63b0454 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -8,16 +8,10 @@ import brainpy as bp import brainpy.math as bm -import platform +from .._csr_matvec import csrmv_brainpylib as brainpylib_csr_matvec -import pytest +seed = 1234 -is_manual_test = False -# if platform.system() == 'Windows' and not is_manual_test: -# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - -brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') -taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') def sum_op(op): def func(*args, **kwargs): @@ -26,127 +20,91 @@ def func(*args, **kwargs): return func +taichi_csr_matvec = bm.event.csrmv -class Test_event_csr_matvec(parameterized.TestCase): +class Test_event_csr_matvec_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_csr_matvec, self).__init__(*args, **kwargs) - bm.set_platform(platform) + super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs) + print() + bm.set_platform(platform) - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', - transpose=transpose, - shape=shape, - homo_data=homo_data, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (10000, 2)] - for homo_data in [-1., 0., 1.] 
+ @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000)], + homo_data=[-1., 0., 1.], ) - def test_homo(self, shape, transpose, homo_data): + def test_homo(self, transpose, shape, homo_data): print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') events = rng.random(shape[0] if transpose else shape[1]) < 0.1 heter_data = bm.ones(indices.shape) * homo_data r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) + r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r3 = brainpylib_csr_matvec(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r3)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r4 = (events @ dense) if transpose else (dense @ events) - self.assertTrue(bm.allclose(r1, r4)) - - r5 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r5)) + assert (bm.allclose(r1, r2)) bm.clear_buffer_memory() - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', - transpose=transpose, - shape=shape, - homo_data=homo_data, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - for homo_data in [-1., 0., 1.] 
+ @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000)], + homo_data=[-1., 0., 1.], ) def test_homo_vmap(self, shape, transpose, homo_data): print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) - f2 = jax.vmap( - partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), - shape=shape, transpose=transpose)) + f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, + shape=shape, transpose=transpose)) vmap_data = bm.as_jax([homo_data] * 10) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) # vmap 'events' f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), homo_data, indices, indptr, + f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) + self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose, - method='cusparse')) + f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + vmap_data1 = 
bm.as_jax([homo_data] * 10) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2.astype(float)))) + f6(vmap_data1, vmap_data2))) bm.clear_buffer_memory() - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose},shape={shape},homo_data={homo_data}', - homo_data=homo_data, - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - for homo_data in [-1., 0., 1.] + @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000)], + homo_data=[-1., 0., 1.], ) def test_homo_grad(self, shape, transpose, homo_data): print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -156,44 +114,29 @@ def test_homo_grad(self, shape, transpose, homo_data): # grad 'data' r1 = jax.grad(sum_op(brainpylib_csr_matvec))( homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(taichi_csr_matvec))( + homo_data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) - r3 = jax.grad(sum_op(lambda a: (events @ (dense_conn * a) if transpose else - ((dense_conn * a) @ events))))(homo_data) - self.assertTrue(bm.allclose(r1, r3)) # grad 'events' - r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, 
transpose=transpose) - r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( + r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda e: (e @ (dense_conn * homo_data) if transpose else - ((dense_conn * homo_data) @ e))))(events.astype(float)) - self.assertTrue(bm.allclose(r4, r5)) - self.assertTrue(bm.allclose(r4, r6)) + self.assertTrue(bm.allclose(r3, r4)) bm.clear_buffer_memory() - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}', - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (10000, 2)] + @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000), ] ) def test_heter(self, shape, transpose): print(f'test_heter: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -202,39 +145,24 @@ def test_heter(self, shape, transpose): r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = partial(bm.sparse.csrmv, method='cusparse')(heter_data, indices, indptr, events.astype(float), - shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r3 = (events @ dense) if transpose else (dense @ events) - self.assertTrue(bm.allclose(r1, r3)) + r2 = taichi_csr_matvec(heter_data, indices, indptr, events, + shape=shape, transpose=transpose) - r4 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), - shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r4)) + 
assert (bm.allclose(r1, r2)) bm.clear_buffer_memory() - @parameterized.named_parameters( - dict( - testcase_name=f"transpose={transpose}, shape={shape}", - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] + @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000)] ) def test_heter_vmap(self, shape, transpose): print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -243,9 +171,8 @@ def test_heter_vmap(self, shape, transpose): events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) - f2 = jax.vmap( - partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), - shape=shape, transpose=transpose)) + f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, + shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) @@ -253,41 +180,34 @@ def test_heter_vmap(self, shape, transpose): data = bm.as_jax(rng.random(indices.shape)) f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), data, indices, indptr, + f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), 
f4(vmap_data.astype(float)))) + self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: partial(bm.sparse.csrmv, method='cusparse')(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, + shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2.astype(float)))) + f6(vmap_data1, vmap_data2))) bm.clear_buffer_memory() - @parameterized.named_parameters( - dict(testcase_name=f'transpose={transpose},shape={shape}', - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] + @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000)] ) def test_heter_grad(self, shape, transpose): print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -299,25 +219,22 @@ def test_heter_grad(self, shape, transpose): data = bm.as_jax(rng.random(indices.shape)) r1 = jax.grad(sum_op(brainpylib_csr_matvec))( data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(taichi_csr_matvec))( + data, indices, indptr, events, shape=shape, transpose=transpose) 
self.assertTrue(bm.allclose(r1, r2)) - dense_data = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape) - r3 = jax.grad(sum_op(lambda a: ((events @ a) if transpose else - (a @ events))))(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - r3 = r3[rows, cols] - self.assertTrue(bm.allclose(r1, r3)) - # grad 'events' - r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( + data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r3, r4)) + + r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( + r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda e: ((e @ dense_data) if transpose else - (dense_data @ e))))(events.astype(float)) - self.assertTrue(bm.allclose(r4, r5)) - self.assertTrue(bm.allclose(r4, r6)) + self.assertTrue(bm.allclose(r5[0], r6[0])) + self.assertTrue(bm.allclose(r5[1], r6[1])) bm.clear_buffer_memory() diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py b/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py deleted file mode 100644 index a5b8df152..000000000 --- a/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py +++ /dev/null @@ -1,15 +0,0 @@ -# -*- coding: utf-8 -*- - - -import jax -import pytest - -import test_event_csrmv - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_event_csr_matvec_GPU(test_event_csrmv.Test_event_csr_matvec): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, platform='gpu') 
diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_old.py b/brainpy/_src/math/event/tests/test_event_csrmv_old.py new file mode 100644 index 000000000..31a6527a2 --- /dev/null +++ b/brainpy/_src/math/event/tests/test_event_csrmv_old.py @@ -0,0 +1,324 @@ +# -*- coding: utf-8 -*- + + +from functools import partial + +import jax +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm +import platform + +import pytest +pytest.skip('Old implementation.', allow_module_level=True) + +is_manual_test = False +# if platform.system() == 'Windows' and not is_manual_test: +# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) + +brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') +taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + +class Test_event_csr_matvec(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_event_csr_matvec, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.named_parameters( + dict( + testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', + transpose=transpose, + shape=shape, + homo_data=homo_data, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (10000, 2)] + for homo_data in [-1., 0., 1.] 
+ ) + def test_homo(self, shape, transpose, homo_data): + print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + heter_data = bm.ones(indices.shape) * homo_data + + r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + + r3 = brainpylib_csr_matvec(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r3)) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r4 = (events @ dense) if transpose else (dense @ events) + self.assertTrue(bm.allclose(r1, r4)) + + r5 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r5)) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict( + testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', + transpose=transpose, + shape=shape, + homo_data=homo_data, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (100000, 2)] + for homo_data in [-1., 0., 1.] 
+ ) + def test_homo_vmap(self, shape, transpose, homo_data): + print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + + # vmap 'data' + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, + shape=shape, transpose=transpose)) + f2 = jax.vmap( + partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), + shape=shape, transpose=transpose)) + vmap_data = bm.as_jax([homo_data] * 10) + self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) + + # vmap 'events' + f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr, + shape=shape, transpose=transpose)) + f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), homo_data, indices, indptr, + shape=shape, transpose=transpose)) + vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 + self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) + + # vmap 'data' and 'events' + f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose, + method='cusparse')) + vmap_data1 = bm.as_jax([homo_data] * 10) + vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 + self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), + f6(vmap_data1, vmap_data2.astype(float)))) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict( + testcase_name=f'transpose={transpose},shape={shape},homo_data={homo_data}', + homo_data=homo_data, + shape=shape, + transpose=transpose, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + 
(10, 1000), + (2, 10000), + (1000, 10), + (100000, 2)] + for homo_data in [-1., 0., 1.] + ) + def test_homo_grad(self, shape, transpose, homo_data): + print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) + + # grad 'data' + r1 = jax.grad(sum_op(brainpylib_csr_matvec))( + homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( + homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + r3 = jax.grad(sum_op(lambda a: (events @ (dense_conn * a) if transpose else + ((dense_conn * a) @ events))))(homo_data) + self.assertTrue(bm.allclose(r1, r3)) + + # grad 'events' + r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( + homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r6 = jax.grad(sum_op(lambda e: (e @ (dense_conn * homo_data) if transpose else + ((dense_conn * homo_data) @ e))))(events.astype(float)) + self.assertTrue(bm.allclose(r4, r5)) + self.assertTrue(bm.allclose(r4, r6)) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict( + testcase_name=f'transpose={transpose}, shape={shape}', + shape=shape, + transpose=transpose, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (10000, 2)] + ) + def test_heter(self, shape, transpose): + 
print(f'test_heter: shape = {shape}, transpose = {transpose}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + heter_data = bm.as_jax(rng.random(indices.shape)) + + r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events, + shape=shape, transpose=transpose) + r2 = partial(bm.sparse.csrmv, method='cusparse')(heter_data, indices, indptr, events.astype(float), + shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r3 = (events @ dense) if transpose else (dense @ events) + self.assertTrue(bm.allclose(r1, r3)) + + r4 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), + shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r4)) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict( + testcase_name=f"transpose={transpose}, shape={shape}", + shape=shape, + transpose=transpose, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (100000, 2)] + ) + def test_heter_vmap(self, shape, transpose): + print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + + # vmap 'data' + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, + shape=shape, transpose=transpose)) + f2 = jax.vmap( + partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), + shape=shape, transpose=transpose)) + vmap_data = 
bm.as_jax(rng.random((10, indices.shape[0]))) + self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) + + # vmap 'events' + data = bm.as_jax(rng.random(indices.shape)) + f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr, + shape=shape, transpose=transpose)) + f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), data, indices, indptr, + shape=shape, transpose=transpose)) + vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 + self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) + + # vmap 'data' and 'events' + f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, + shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: partial(bm.sparse.csrmv, method='cusparse')(dd, indices, indptr, ee, + shape=shape, transpose=transpose)) + vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) + vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 + self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), + f6(vmap_data1, vmap_data2.astype(float)))) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=f'transpose={transpose},shape={shape}', + shape=shape, + transpose=transpose, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (100000, 2)] + ) + def test_heter_grad(self, shape, transpose): + print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) + + # grad 'data' + data = bm.as_jax(rng.random(indices.shape)) + r1 = 
jax.grad(sum_op(brainpylib_csr_matvec))( + data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( + data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + + dense_data = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape) + r3 = jax.grad(sum_op(lambda a: ((events @ a) if transpose else + (a @ events))))(dense_data) + rows, cols = bm.sparse.csr_to_coo(indices, indptr) + r3 = r3[rows, cols] + self.assertTrue(bm.allclose(r1, r3)) + + # grad 'events' + r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( + data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r6 = jax.grad(sum_op(lambda e: ((e @ dense_data) if transpose else + (dense_data @ e))))(events.astype(float)) + self.assertTrue(bm.allclose(r4, r5)) + self.assertTrue(bm.allclose(r4, r6)) + + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py b/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py deleted file mode 100644 index 781e3c91c..000000000 --- a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py +++ /dev/null @@ -1,240 +0,0 @@ -# -*- coding: utf-8 -*- - - -from functools import partial - -import jax -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -seed = 1234 - - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - -brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') -taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') - -class Test_event_csr_matvec_taichi(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_csr_matvec_taichi, 
self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], - ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - heter_data = bm.ones(indices.shape) * homo_data - - r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - - assert (bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], - ) - def test_homo_vmap(self, shape, transpose, homo_data): - print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax([homo_data] * 10) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) - - # vmap 'events' - f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, 
shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - - vmap_data1 = bm.as_jax([homo_data] * 10) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2))) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], - ) - def test_homo_grad(self, shape, transpose, homo_data): - print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - # grad 'events' - r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), 
- (200, 200), - (200, 100), - (10, 1000), ] - ) - def test_heter(self, shape, transpose): - print(f'test_heter: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - heter_data = bm.as_jax(rng.random(indices.shape)) - - r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) - r2 = taichi_csr_matvec(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) - - assert (bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] - ) - def test_heter_vmap(self, shape, transpose): - print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) - - # vmap 'events' - data = bm.as_jax(rng.random(indices.shape)) - f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - 
self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2))) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] - ) - def test_heter_grad(self, shape, transpose): - print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - data = bm.as_jax(rng.random(indices.shape)) - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( - data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - # grad 'events' - r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r3, r4)) - - r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - 
r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 24f66878e..62b665f47 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -5,20 +5,10 @@ import jax.numpy as jnp from absl.testing import parameterized -import platform import brainpy.math as bm -import pytest - -is_manual_test = False -if platform.system() == 'Windows' and not is_manual_test: - pytest.skip('Under windows, brainpy.math package may need manual tests.', allow_module_level=True) - -shapes = [(100, 200), - # (10, 1000), - (2, 1000), - # (1000, 10), - (1000, 2)] +shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] +shapes = [(100, 200), (2, 1000), (1000, 2)] brainpylib_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='brainpylib') taichi_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='taichi') @@ -60,32 +50,32 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve if not bool_event: events = events.astype(float) - r1 = brainpylib_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = taichi_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = brainpylib_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r2 = taichi_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + 
transpose=transpose) r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = brainpylib_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r3 = taichi_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) r3 = jax.block_until_ready(r3) self.assertTrue(jnp.allclose(r1, r3)) @@ -127,10 +117,10 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=Tru weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: brainpylib_mv_prob_homo( + lambda event, data: taichi_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, transpose=transpose, outdim_parallel=outdim_parallel - ) + )[0] ) r1 = f1(events, weights) r1 = jax.block_until_ready(r1) @@ -171,10 +161,9 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: brainpylib_mv_prob_homo( + lambda event, data: taichi_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - ).sum(), + outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), argnums=0 ) r1 = f1(events, 1.) 
@@ -238,35 +227,35 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, if not bool_event: events = events.astype(float) - r1 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r2 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r3 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) r3 = jax.block_until_ready(r3) self.assertTrue(jnp.allclose(r1, r3)) if x64: @@ -309,14 +298,14 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, events = events.astype(float) f1 = jax.vmap( - lambda e: brainpylib_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + lambda e: taichi_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) ) r1 = f1(events) @@ -359,7 +348,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, 
prob, seed=None, events = events.astype(float) f1 = jax.grad( - lambda e, w_high: brainpylib_mv_prob_uniform( + lambda e, w_high: taichi_mv_prob_uniform( e, w_low=0., w_high=w_high, @@ -420,35 +409,35 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, if not bool_event: events = events.astype(float) - r1 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r2 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r3 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) r3 = jax.block_until_ready(r3) self.assertTrue(jnp.allclose(r1, r3)) @@ -493,14 +482,14 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, if not bool_event: events = events.astype(float) - f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) + f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + 
conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) r1 = f1(events) r1 = jax.block_until_ready(r1) r2 = f1(events) @@ -543,7 +532,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x f1 = jax.jit( jax.grad( - lambda e, w_sigma: brainpylib_mv_prob_normal( + lambda e, w_sigma: taichi_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec_gpu.py b/brainpy/_src/math/jitconn/tests/test_event_matvec_gpu.py deleted file mode 100644 index 778212547..000000000 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec_gpu.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax -import pytest - -import test_event_matvec - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_event_matvec_prob_conn_GPU(test_event_matvec.Test_event_matvec_prob_conn): - def __init__(self, *args, **kwargs): - super(Test_event_matvec_prob_conn_GPU, self).__init__(*args, **kwargs, platform='gpu') diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py b/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py similarity index 73% rename from brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py rename to brainpy/_src/math/jitconn/tests/test_event_matvec_old.py index 62b665f47..b2fa77229 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py @@ -5,10 +5,20 @@ import jax.numpy as jnp from absl.testing import parameterized +import platform import brainpy.math as bm -shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] +import pytest +pytest.skip('Old implementation.', allow_module_level=True) +is_manual_test = False +if platform.system() == 'Windows' and not is_manual_test: + pytest.skip('Under windows, brainpy.math package may need 
manual tests.', allow_module_level=True) + +shapes = [(100, 200), + # (10, 1000), + (2, 1000), + # (1000, 10), + (1000, 2)] brainpylib_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='brainpylib') taichi_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='taichi') @@ -50,32 +60,32 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve if not bool_event: events = events.astype(float) - r1 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r2 = brainpylib_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r3 = brainpylib_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) r3 = jax.block_until_ready(r3) self.assertTrue(jnp.allclose(r1, r3)) @@ -117,10 +127,10 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=Tru weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, transpose=transpose, outdim_parallel=outdim_parallel - )[0] + ) ) r1 = f1(events, weights) r1 = jax.block_until_ready(r1) @@ -161,9 
+171,10 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), + outdim_parallel=outdim_parallel, transpose=transpose + ).sum(), argnums=0 ) r1 = f1(events, 1.) @@ -227,35 +238,35 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, if not bool_event: events = events.astype(float) - r1 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r2 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r3 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) r3 = jax.block_until_ready(r3) self.assertTrue(jnp.allclose(r1, r3)) if x64: @@ -298,14 +309,14 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, events = events.astype(float) f1 = 
jax.vmap( - lambda e: taichi_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + lambda e: brainpylib_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) ) r1 = f1(events) @@ -348,7 +359,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, events = events.astype(float) f1 = jax.grad( - lambda e, w_high: taichi_mv_prob_uniform( + lambda e, w_high: brainpylib_mv_prob_uniform( e, w_low=0., w_high=w_high, @@ -409,35 +420,35 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, if not bool_event: events = events.astype(float) - r1 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r2 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r3 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) r3 = jax.block_until_ready(r3) 
self.assertTrue(jnp.allclose(r1, r3)) @@ -482,14 +493,14 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, if not bool_event: events = events.astype(float) - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) + f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) r1 = f1(events) r1 = jax.block_until_ready(r1) r2 = f1(events) @@ -532,7 +543,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x f1 = jax.jit( jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( + lambda e, w_sigma: brainpylib_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 5176a13a8..c857e2e2e 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -4,27 +4,18 @@ import jax import jax.numpy as jnp from absl.testing import parameterized +from .._matvec import (mv_prob_homo_brainpylib as brainpylib_mv_prob_homo, + mv_prob_uniform_brainpylib as brainpylib_mv_prob_uniform, + mv_prob_normal_brainpylib as brainpylib_mv_prob_normal,) import brainpy.math as bm -import platform -import pytest - -is_manual_test = False -if platform.system() == 'Windows' and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - -shapes = [(100, 200), - (10, 1000), - (2, 1000), - (1000, 10), - (1000, 2)] - -brainpylib_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='brainpylib') -taichi_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='taichi') -brainpylib_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='brainpylib') -taichi_mv_prob_uniform = 
partial(bm.jitconn.mv_prob_uniform, method='taichi') -brainpylib_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='brainpylib') -taichi_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='taichi') + +shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] +shapes = [(100, 200), (2, 1000), (1000, 2)] + +taichi_mv_prob_homo = bm.jitconn.mv_prob_homo +taichi_mv_prob_uniform = bm.jitconn.mv_prob_uniform +taichi_mv_prob_normal = bm.jitconn.mv_prob_normal class Test_matvec_prob_conn(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -66,34 +57,32 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non rng = bm.random.RandomState() vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = brainpylib_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = brainpylib_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = taichi_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = taichi_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) self.assertTrue(jnp.allclose(r1, r2)) - r2 = brainpylib_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r2 = taichi_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) self.assertTrue(jnp.allclose(r1, r2)) - if x64: - bm.disable_x64() bm.clear_buffer_memory() @parameterized.named_parameters( @@ -128,11 +117,11 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, 
prob, seed=None, x64 weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: brainpylib_mv_prob_homo( + lambda event, data: taichi_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - ) + )[0] ) r1 = f1(events, weights) r2 = f1(events, weights) @@ -173,14 +162,14 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: brainpylib_mv_prob_homo( + lambda event, data: taichi_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - ).sum(), + )[0].sum(), argnums=0 ) r1 = f1(events, 1.) @@ -230,36 +219,36 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) self.assertTrue(c) - r2 = brainpylib_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r2 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + 
shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) @@ -298,14 +287,14 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: brainpylib_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) + f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) r1 = f1(events) r2 = f1(events) @@ -347,7 +336,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) f1 = jax.grad( - lambda e, w_low, w_high: brainpylib_mv_prob_uniform( + lambda e, w_low, w_high: taichi_mv_prob_uniform( e, w_low=w_low, w_high=w_high, @@ -356,7 +345,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - ).sum() + )[0].sum() ) r1 = f1(events, 0., 1.) 
@@ -407,36 +396,36 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) self.assertTrue(c) - r2 = brainpylib_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r2 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) @@ -476,19 +465,20 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) + f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) r1 = f1(events) 
r2 = f1(events) - c = jnp.allclose(r1, r2) + c = jnp.allclose(r1, r2, atol=1e-6) if not c: print(r1, r2) + print(r1 - r2) self.assertTrue(c) if x64: @@ -528,7 +518,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x events = events.astype(float) f1 = jax.grad( - lambda e, w_sigma: brainpylib_mv_prob_normal( + lambda e, w_sigma: taichi_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, @@ -537,12 +527,10 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - ).sum() + )[0].sum() ) r1 = f1(events, 1.) r2 = f1(events, 2.) - print('r1:', r1) - print('r2:', r2) self.assertTrue(bm.allclose(r1 * 2., r2)) if x64: diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_gpu.py b/brainpy/_src/math/jitconn/tests/test_matvec_gpu.py deleted file mode 100644 index f227c0e6a..000000000 --- a/brainpy/_src/math/jitconn/tests/test_matvec_gpu.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax -import pytest - -import test_matvec - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_matvec_prob_conn_GPU(test_matvec.Test_matvec_prob_conn): - def __init__(self, *args, **kwargs): - super(Test_matvec_prob_conn_GPU, self).__init__(*args, **kwargs, platform='gpu') diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py b/brainpy/_src/math/jitconn/tests/test_matvec_old.py similarity index 70% rename from brainpy/_src/math/jitconn/tests/test_matvec_taichi.py rename to brainpy/_src/math/jitconn/tests/test_matvec_old.py index 8f42831d5..360711e7b 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec_old.py @@ -6,9 +6,19 @@ from absl.testing import parameterized import brainpy.math as bm +import platform +import pytest -shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), 
(1000, 2)] +pytest.skip('Old implementation.', allow_module_level=True) +is_manual_test = False +if platform.system() == 'Windows' and not is_manual_test: + pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) + +shapes = [(100, 200), + (10, 1000), + (2, 1000), + (1000, 10), + (1000, 2)] brainpylib_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='brainpylib') taichi_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='taichi') @@ -57,32 +67,34 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non rng = bm.random.RandomState() vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = taichi_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = taichi_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = brainpylib_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) self.assertTrue(jnp.allclose(r1, r2)) - r2 = taichi_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r2 = brainpylib_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) self.assertTrue(jnp.allclose(r1, r2)) + if x64: + bm.disable_x64() bm.clear_buffer_memory() @parameterized.named_parameters( @@ -117,11 +129,11 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64 weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( + 
lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - )[0] + ) ) r1 = f1(events, weights) r2 = f1(events, weights) @@ -162,14 +174,14 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - )[0].sum(), + ).sum(), argnums=0 ) r1 = f1(events, 1.) @@ -219,36 +231,36 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) self.assertTrue(c) - r2 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r2 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) c = 
jnp.allclose(r1, r2) if not c: print(r1, r2) @@ -287,14 +299,14 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) + f1 = jax.vmap(lambda e: brainpylib_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) r1 = f1(events) r2 = f1(events) @@ -336,7 +348,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) f1 = jax.grad( - lambda e, w_low, w_high: taichi_mv_prob_uniform( + lambda e, w_low, w_high: brainpylib_mv_prob_uniform( e, w_low=w_low, w_high=w_high, @@ -345,7 +357,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - )[0].sum() + ).sum() ) r1 = f1(events, 0., 1.) 
@@ -396,36 +408,36 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) self.assertTrue(c) - r2 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r2 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) @@ -465,20 +477,19 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) + f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) r1 = f1(events) 
r2 = f1(events) - c = jnp.allclose(r1, r2, atol=1e-6) + c = jnp.allclose(r1, r2) if not c: print(r1, r2) - print(r1 - r2) self.assertTrue(c) if x64: @@ -518,7 +529,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x events = events.astype(float) f1 = jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( + lambda e, w_sigma: brainpylib_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, @@ -527,10 +538,12 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - )[0].sum() + ).sum() ) r1 = f1(events, 1.) r2 = f1(events, 2.) + print('r1:', r1) + print('r2:', r2) self.assertTrue(bm.allclose(r1 * 2., r2)) if x64: diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index b902cd8a0..123ca657e 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -3,24 +3,61 @@ from functools import partial import jax -import pytest from absl.testing import parameterized -import platform + import brainpy as bp import brainpy.math as bm +from .._csr_mv import csrmv_brainpylib as brainpylib_csr_matvec + +# bm.set_platform('gpu') + +seed = 1234 + + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + + +def compare_with_nan_tolerance(a, b, tol=1e-8): + """ + Compare two arrays with tolerance for NaN values. + + Parameters: + a (np.array): First array to compare. + b (np.array): Second array to compare. + tol (float): Tolerance for comparing non-NaN elements. + + Returns: + bool: True if arrays are similar within the tolerance, False otherwise. 
+ """ + if a.shape != b.shape: + return False + + # Create masks for NaNs in both arrays + nan_mask_a = bm.isnan(a) + nan_mask_b = bm.isnan(b) -is_manual_test = False -# if platform.system() == 'Windows' and not is_manual_test: -# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) + # Check if NaN positions are the same in both arrays + if not bm.array_equal(nan_mask_a, nan_mask_b): + return False -cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse') -scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar') -vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') + # Compare non-NaN elements + a_non_nan = a[~nan_mask_a] + b_non_nan = b[~nan_mask_b] + return bm.allclose(a_non_nan, b_non_nan, atol=tol) -class Test_cusparse_csrmv(parameterized.TestCase): + +taichi_csr_matvec = bm.sparse.csrmv + +class Test_csrmv_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): - super(Test_cusparse_csrmv, self).__init__(*args, **kwargs) + super(Test_csrmv_taichi, self).__init__(*args, **kwargs) print() bm.set_platform(platform) @@ -31,35 +68,33 @@ def __init__(self, *args, platform='cpu', **kwargs): homo_data=[-1., 0., 1.] 
) def test_homo(self, transpose, shape, homo_data): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') + conn = bp.conn.FixedProb(0.3) + # matrix indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) - - heter_data = bm.ones(indices.shape).value * homo_data - + # vector + rng = bm.random.RandomState(seed=seed) vector = rng.random(shape[0] if transpose else shape[1]) vector = bm.as_jax(vector) - r1 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r3 = (vector @ dense) if transpose else (dense @ vector) - self.assertTrue(bm.allclose(r1, r3)) + r1 = brainpylib_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + shape=[(200, 200), (200, 100), (100, 1000), (2, 2000)], v=[-1., 0., 1.] 
) def test_homo_vmap(self, transpose, shape, v): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) @@ -71,17 +106,14 @@ def test_homo_vmap(self, transpose, shape, v): homo_data = bm.ones(10).value * v dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, + f1 = partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, vector=vector, + shape=shape, transpose=transpose) + f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) - f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) - r1 = jax.vmap(f1)(homo_data) - r2 = jax.vmap(f1)(heter_data) + r2 = jax.vmap(f1)(homo_data) self.assertTrue(bm.allclose(r1, r2)) - r3 = jax.vmap(f2)(dense_data) - self.assertTrue(bm.allclose(r1, r3)) - bm.clear_buffer_memory() @parameterized.product( @@ -90,8 +122,9 @@ def test_homo_vmap(self, transpose, shape, v): homo_data=[-1., 0., 1.] 
) def test_homo_grad(self, transpose, shape, homo_data): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) @@ -103,37 +136,36 @@ def test_homo_grad(self, transpose, shape, homo_data): vector = rng.random(shape[0] if transpose else shape[1]) vector = bm.as_jax(vector) - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, - shape=shape, transpose=transpose).sum(), - argnums=0) - dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum() - if transpose else - ((dense * a) @ vector).sum()), - argnums=0) - - r1 = csr_f1(homo_data) - r2 = dense_f1(homo_data) + # print('grad data start') + # grad 'data' + r1 = jax.grad(sum_op(brainpylib_csr_matvec))( + homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(taichi_csr_matvec))( + homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + + # csr_f1 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, + # shape=shape, transpose=transpose).sum(), + # argnums=0) + # csr_f2 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, + # shape=shape, transpose=transpose)[0].sum(), + # argnums=0) + # r1 = csr_f1(homo_data) + # r2 = csr_f2(homo_data) self.assertTrue(bm.allclose(r1, r2)) - csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(homo_data, indices, indptr, v, - shape=shape, transpose=transpose).sum()) - dense_data = dense * homo_data - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) + # print('grad vector start') + # grad 'vector' + r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r4 = jax.grad(sum_op(taichi_csr_matvec), 
argnums=3)( + homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r3 = csr_f2(vector) - r4 = dense_f2(vector) self.assertTrue(bm.allclose(r3, r4)) - csr_f3 = jax.grad(lambda a, v: cusparse_csr_matvec(a, indices, indptr, v, - shape=shape, transpose=transpose).sum(), - argnums=(0, 1)) - dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() - if transpose else - ((dense * a) @ v).sum()), - argnums=(0, 1)) - - r5 = csr_f3(homo_data, vector) - r6 = dense_f3(homo_data, vector) + r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))( + homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( + homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) @@ -141,26 +173,27 @@ def test_homo_grad(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + shape=[(200, 200), (200, 100), (2, 2000)], ) def test_heter(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + print(f'test_homo: transpose = {transpose} shape = {shape}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) - heter_data = rng.random(indices.shape) + heter_data = bm.as_jax(rng.random(indices.shape)) heter_data = bm.as_jax(heter_data) vector = rng.random(shape[0] if transpose else shape[1]) vector = bm.as_jax(vector) - r1 = cusparse_csr_matvec(heter_data, indices, indptr, vector, - shape=shape, transpose=transpose) - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r2 = (vector @ dense) if transpose else (dense @ vector) - self.assertTrue(bm.allclose(r1, r2)) + + r1 = 
brainpylib_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + + self.assertTrue(compare_with_nan_tolerance(r1, r2)) bm.clear_buffer_memory() @@ -169,8 +202,8 @@ def test_heter(self, transpose, shape): shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] ) def test_heter_vmap(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) @@ -183,23 +216,21 @@ def test_heter_vmap(self, transpose, shape): dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, + f1 = partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, vector=vector, + shape=shape, transpose=transpose) + f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) - f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) - r1 = jax.vmap(f1)(heter_data) - r2 = jax.vmap(f2)(dense_data) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() + r2 = jax.vmap(f2)(heter_data) + self.assertTrue(compare_with_nan_tolerance(r1, r2)) @parameterized.product( transpose=[True, False], shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] ) def test_heter_grad(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) @@ -210,141 +241,25 @@ def test_heter_grad(self, transpose, shape): vector = rng.random(shape[0] if transpose else shape[1]) vector = bm.as_jax(vector) - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, - shape=shape, - 
transpose=transpose).sum(), - argnums=0) - dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), - argnums=0) - - r1 = csr_f1(heter_data) - r2 = dense_f1(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - r2 = r2[rows, cols] + # grad 'data' + r1 = jax.grad(sum_op(brainpylib_csr_matvec))( + heter_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(taichi_csr_matvec))( + heter_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) - csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, - shape=shape, - transpose=transpose).sum(), - argnums=0) - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), - argnums=0) - r3 = csr_f2(vector) - r4 = dense_f2(vector) + # grad 'vector' + r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( + heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) - bm.clear_buffer_memory() - - -class Test_csrmv(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmv, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - homo_data=[-1., 0., 0.1, 1.], - shape=[(100, 200), (10, 1000), (2, 2000)], - ) - def test_homo(self, shape, homo_data): - conn = bp.conn.FixedProb(0.1) - - # matrix - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # vector - rng = bm.random.RandomState(123) - vector = rng.random(shape[1]) - vector = bm.as_jax(vector) - - # csrmv - r1 = scalar_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - r2 = cusparse_csr_matvec(homo_data, indices, indptr, vector, 
shape=shape) - r3 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - - heter_data = bm.ones(indices.shape).to_jax() * homo_data - r4 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r5 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r6 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - self.assertTrue(bm.allclose(r1, r4)) - self.assertTrue(bm.allclose(r1, r5)) - self.assertTrue(bm.allclose(r1, r6)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - rdense = dense @ vector - self.assertTrue(bm.allclose(r1, rdense)) - - bm.clear_buffer_memory() - - @parameterized.product( - shape=[(100, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter(self, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = bm.as_jax(rng.random(indices.shape)) - vector = bm.as_jax(rng.random(shape[1])) - - r1 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r3 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r4 = dense @ vector - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - self.assertTrue(bm.allclose(r1, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_grad(self, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - heter_data = rng.random(indices.shape) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = 
rng.random(shape[1]) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - csr_f2 = jax.grad(lambda a: scalar_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - csr_f3 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - dense_f1 = jax.grad(lambda a: (a @ vector).sum()) - - r1 = csr_f1(heter_data) - r2 = csr_f2(heter_data) - r3 = csr_f3(heter_data) - - d1 = dense_f1(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - d1 = d1[rows, cols] - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - self.assertTrue(bm.allclose(r1, d1)) - - # csr_f4 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # csr_f5 = jax.grad(lambda v: scalar_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # csr_f6 = jax.grad(lambda v: vector_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # dense_f2 = jax.grad(lambda v: (dense_data @ v).sum()) - # r4 = csr_f4(vector) - # r5 = csr_f5(vector) - # r6 = csr_f6(vector) - # d2 = dense_f2(vector) - # self.assertTrue(bm.allclose(r4, r5)) - # self.assertTrue(bm.allclose(r4, r6)) - # self.assertTrue(bm.allclose(r4, d2)) + r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))( + heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( + heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r5[0], r6[0])) + self.assertTrue(bm.allclose(r5[1], r6[1])) bm.clear_buffer_memory() - - diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_gpu.py b/brainpy/_src/math/sparse/tests/test_csrmv_gpu.py deleted file mode 100644 index ccf090ec4..000000000 --- a/brainpy/_src/math/sparse/tests/test_csrmv_gpu.py +++ /dev/null @@ -1,21 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax -import pytest - -import test_csrmv 
- -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_cusparse_csrmv_GPU(test_csrmv.Test_cusparse_csrmv): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, platform='gpu') - - -class Test__csrmv_GPU(test_csrmv.Test_csrmv): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, platform='gpu') - - diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_old.py b/brainpy/_src/math/sparse/tests/test_csrmv_old.py new file mode 100644 index 000000000..b73217496 --- /dev/null +++ b/brainpy/_src/math/sparse/tests/test_csrmv_old.py @@ -0,0 +1,352 @@ +# -*- coding: utf-8 -*- + +from functools import partial + +import jax +import pytest +from absl.testing import parameterized +import platform +import brainpy as bp +import brainpy.math as bm + +pytest.skip('Old implementation.', allow_module_level=True) + +is_manual_test = False +# if platform.system() == 'Windows' and not is_manual_test: +# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) + +cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse') +scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar') +vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') + + +class Test_cusparse_csrmv(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_cusparse_csrmv, self).__init__(*args, **kwargs) + + print() + bm.set_platform(platform) + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + homo_data=[-1., 0., 1.] 
+ ) + def test_homo(self, transpose, shape, homo_data): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + + heter_data = bm.ones(indices.shape).value * homo_data + + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + r1 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r3 = (vector @ dense) if transpose else (dense @ vector) + self.assertTrue(bm.allclose(r1, r3)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + v=[-1., 0., 1.] + ) + def test_homo_vmap(self, transpose, shape, v): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + + heter_data = bm.ones((10, indices.shape[0])).value * v + homo_data = bm.ones(10).value * v + dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) + + f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, + shape=shape, transpose=transpose) + f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) + + r1 = jax.vmap(f1)(homo_data) + r2 = jax.vmap(f1)(heter_data) + self.assertTrue(bm.allclose(r1, r2)) + + r3 = jax.vmap(f2)(dense_data) + self.assertTrue(bm.allclose(r1, r3)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + homo_data=[-1., 0., 1.] 
+ ) + def test_homo_grad(self, transpose, shape, homo_data): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, + indices, + indptr, + shape=shape) + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + + csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, + shape=shape, transpose=transpose).sum(), + argnums=0) + dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum() + if transpose else + ((dense * a) @ vector).sum()), + argnums=0) + + r1 = csr_f1(homo_data) + r2 = dense_f1(homo_data) + self.assertTrue(bm.allclose(r1, r2)) + + csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(homo_data, indices, indptr, v, + shape=shape, transpose=transpose).sum()) + dense_data = dense * homo_data + dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) + + r3 = csr_f2(vector) + r4 = dense_f2(vector) + self.assertTrue(bm.allclose(r3, r4)) + + csr_f3 = jax.grad(lambda a, v: cusparse_csr_matvec(a, indices, indptr, v, + shape=shape, transpose=transpose).sum(), + argnums=(0, 1)) + dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() + if transpose else + ((dense * a) @ v).sum()), + argnums=(0, 1)) + + r5 = csr_f3(homo_data, vector) + r6 = dense_f3(homo_data, vector) + self.assertTrue(bm.allclose(r5[0], r6[0])) + self.assertTrue(bm.allclose(r5[1], r6[1])) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + ) + def test_heter(self, transpose, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + + heter_data = rng.random(indices.shape) + heter_data = bm.as_jax(heter_data) 
+ + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + r1 = cusparse_csr_matvec(heter_data, indices, indptr, vector, + shape=shape, transpose=transpose) + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r2 = (vector @ dense) if transpose else (dense @ vector) + self.assertTrue(bm.allclose(r1, r2)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + ) + def test_heter_vmap(self, transpose, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + + heter_data = rng.random((10, indices.shape[0])) + heter_data = bm.as_jax(heter_data) + dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, + shape=shape))(heter_data) + + f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, + shape=shape, transpose=transpose) + f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) + + r1 = jax.vmap(f1)(heter_data) + r2 = jax.vmap(f2)(dense_data) + self.assertTrue(bm.allclose(r1, r2)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + ) + def test_heter_grad(self, transpose, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + heter_data = rng.random(indices.shape) + heter_data = bm.as_jax(heter_data) + dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + + csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, 
+ shape=shape, + transpose=transpose).sum(), + argnums=0) + dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), + argnums=0) + + r1 = csr_f1(heter_data) + r2 = dense_f1(dense_data) + rows, cols = bm.sparse.csr_to_coo(indices, indptr) + r2 = r2[rows, cols] + self.assertTrue(bm.allclose(r1, r2)) + + csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, + shape=shape, + transpose=transpose).sum(), + argnums=0) + dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), + argnums=0) + r3 = csr_f2(vector) + r4 = dense_f2(vector) + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() + + +class Test_csrmv(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_csrmv, self).__init__(*args, **kwargs) + + print() + bm.set_platform(platform) + + @parameterized.product( + homo_data=[-1., 0., 0.1, 1.], + shape=[(100, 200), (10, 1000), (2, 2000)], + ) + def test_homo(self, shape, homo_data): + conn = bp.conn.FixedProb(0.1) + + # matrix + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # vector + rng = bm.random.RandomState(123) + vector = rng.random(shape[1]) + vector = bm.as_jax(vector) + + # csrmv + r1 = scalar_csr_matvec(homo_data, indices, indptr, vector, shape=shape) + r2 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape) + r3 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape) + self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, r3)) + + heter_data = bm.ones(indices.shape).to_jax() * homo_data + r4 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r5 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r6 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + self.assertTrue(bm.allclose(r1, r4)) + self.assertTrue(bm.allclose(r1, r5)) + 
self.assertTrue(bm.allclose(r1, r6)) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + rdense = dense @ vector + self.assertTrue(bm.allclose(r1, rdense)) + + bm.clear_buffer_memory() + + @parameterized.product( + shape=[(100, 200), (200, 100), (10, 1000), (2, 2000)] + ) + def test_heter(self, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + heter_data = bm.as_jax(rng.random(indices.shape)) + vector = bm.as_jax(rng.random(shape[1])) + + r1 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r3 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r4 = dense @ vector + self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, r3)) + self.assertTrue(bm.allclose(r1, r4)) + + bm.clear_buffer_memory() + + @parameterized.product( + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + ) + def test_heter_grad(self, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + heter_data = rng.random(indices.shape) + dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + vector = rng.random(shape[1]) + + csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) + csr_f2 = jax.grad(lambda a: scalar_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) + csr_f3 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) + dense_f1 = jax.grad(lambda a: (a @ vector).sum()) + + r1 = csr_f1(heter_data) + r2 = csr_f2(heter_data) + r3 = csr_f3(heter_data) + + d1 = dense_f1(dense_data) + rows, cols = bm.sparse.csr_to_coo(indices, indptr) + d1 = d1[rows, 
cols] + self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, r3)) + self.assertTrue(bm.allclose(r1, d1)) + + # csr_f4 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) + # csr_f5 = jax.grad(lambda v: scalar_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) + # csr_f6 = jax.grad(lambda v: vector_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) + # dense_f2 = jax.grad(lambda v: (dense_data @ v).sum()) + # r4 = csr_f4(vector) + # r5 = csr_f5(vector) + # r6 = csr_f6(vector) + # d2 = dense_f2(vector) + # self.assertTrue(bm.allclose(r4, r5)) + # self.assertTrue(bm.allclose(r4, r6)) + # self.assertTrue(bm.allclose(r4, d2)) + + bm.clear_buffer_memory() + + diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py deleted file mode 100644 index 5e2a644a4..000000000 --- a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py +++ /dev/null @@ -1,265 +0,0 @@ -# -*- coding: utf-8 -*- - -from functools import partial - -import jax -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -# bm.set_platform('gpu') - -seed = 1234 - - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - - -def compare_with_nan_tolerance(a, b, tol=1e-8): - """ - Compare two arrays with tolerance for NaN values. - - Parameters: - a (np.array): First array to compare. - b (np.array): Second array to compare. - tol (float): Tolerance for comparing non-NaN elements. - - Returns: - bool: True if arrays are similar within the tolerance, False otherwise. 
- """ - if a.shape != b.shape: - return False - - # Create masks for NaNs in both arrays - nan_mask_a = bm.isnan(a) - nan_mask_b = bm.isnan(b) - - # Check if NaN positions are the same in both arrays - if not bm.array_equal(nan_mask_a, nan_mask_b): - return False - - # Compare non-NaN elements - a_non_nan = a[~nan_mask_a] - b_non_nan = b[~nan_mask_b] - - return bm.allclose(a_non_nan, b_non_nan, atol=tol) - - -vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') -taichi_csr_matvec = partial(bm.sparse.csrmv, method='taichi') - -class Test_csrmv_taichi(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmv_taichi, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] - ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') - conn = bp.conn.FixedProb(0.3) - - # matrix - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # vector - rng = bm.random.RandomState(seed=seed) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - r1 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (100, 1000), (2, 2000)], - v=[-1., 0., 1.] 
- ) - def test_homo_vmap(self, transpose, shape, v): - print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = bm.ones((10, indices.shape[0])).value * v - homo_data = bm.ones(10).value * v - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - - f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - r1 = jax.vmap(f1)(homo_data) - r2 = jax.vmap(f1)(homo_data) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] 
- ) - def test_homo_grad(self, transpose, shape, homo_data): - print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - # print('grad data start') - # grad 'data' - r1 = jax.grad(sum_op(vector_csr_matvec))( - homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - - # csr_f1 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, - # shape=shape, transpose=transpose).sum(), - # argnums=0) - # csr_f2 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, - # shape=shape, transpose=transpose)[0].sum(), - # argnums=0) - # r1 = csr_f1(homo_data) - # r2 = csr_f2(homo_data) - self.assertTrue(bm.allclose(r1, r2)) - - # print('grad vector start') - # grad 'vector' - r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - - self.assertTrue(bm.allclose(r3, r4)) - - r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() - - @parameterized.product( - 
transpose=[True, False], - shape=[(200, 200), (200, 100), (2, 2000)], - ) - def test_heter(self, transpose, shape): - print(f'test_homo: transpose = {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = bm.as_jax(rng.random(indices.shape)) - heter_data = bm.as_jax(heter_data) - - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - r1 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - - self.assertTrue(compare_with_nan_tolerance(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_vmap(self, transpose, shape): - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = rng.random((10, indices.shape[0])) - heter_data = bm.as_jax(heter_data) - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, - shape=shape))(heter_data) - - f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - r1 = jax.vmap(f1)(heter_data) - r2 = jax.vmap(f2)(heter_data) - self.assertTrue(compare_with_nan_tolerance(r1, r2)) - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_grad(self, transpose, shape): - rng = bm.random.RandomState(seed=seed) - conn = 
bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = rng.random(indices.shape) - heter_data = bm.as_jax(heter_data) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - # grad 'data' - r1 = jax.grad(sum_op(vector_csr_matvec))( - heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - # grad 'vector' - r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r3, r4)) - - r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() From df8c0bfa2f243321fac2d771da91c7de702b6515 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 29 Jan 2024 13:35:14 +0800 Subject: [PATCH 19/27] [math] Fix pytest bugs --- .../_src/math/event/tests/test_event_csrmv.py | 28 +++++----- brainpy/_src/math/sparse/tests/test_csrmv.py | 53 ++++++++++++------- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 6f63b0454..0598734a7 100644 --- 
a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -8,7 +8,6 @@ import brainpy as bp import brainpy.math as bm -from .._csr_matvec import csrmv_brainpylib as brainpylib_csr_matvec seed = 1234 @@ -44,7 +43,8 @@ def test_homo(self, transpose, shape, homo_data): events = rng.random(shape[0] if transpose else shape[1]) < 0.1 heter_data = bm.ones(indices.shape) * homo_data - r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r1 = (events @ dense) if transpose else (dense @ events) r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -67,7 +67,7 @@ def test_homo_vmap(self, shape, transpose, homo_data): # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, + f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) @@ -75,7 +75,7 @@ def test_homo_vmap(self, shape, transpose, homo_data): self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) # vmap 'events' - f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr, + f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr, shape=shape, transpose=transpose)) f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr, shape=shape, transpose=transpose)) @@ -83,7 +83,7 @@ def test_homo_vmap(self, shape, transpose, homo_data): self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f5 = 
jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax([homo_data] * 10) @@ -112,14 +112,14 @@ def test_homo_grad(self, shape, transpose, homo_data): dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) # grad 'data' - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( + r1 = jax.grad(sum_op(bm.sparse.csrmv))( homo_data, indices, indptr, events, shape=shape, transpose=transpose) r2 = jax.grad(sum_op(taichi_csr_matvec))( homo_data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' - r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) @@ -143,7 +143,7 @@ def test_heter(self, shape, transpose): events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 heter_data = bm.as_jax(rng.random(indices.shape)) - r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events, + r1 = bm.sparse.csrmv(heter_data, indices, indptr, events, shape=shape, transpose=transpose) r2 = taichi_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) @@ -169,7 +169,7 @@ def test_heter_vmap(self, shape, transpose): # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, + f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, 
shape=shape, transpose=transpose)) @@ -178,7 +178,7 @@ def test_heter_vmap(self, shape, transpose): # vmap 'events' data = bm.as_jax(rng.random(indices.shape)) - f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr, + f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr, shape=shape, transpose=transpose)) f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr, shape=shape, transpose=transpose)) @@ -186,7 +186,7 @@ def test_heter_vmap(self, shape, transpose): self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, + f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) @@ -217,20 +217,20 @@ def test_heter_grad(self, shape, transpose): # grad 'data' data = bm.as_jax(rng.random(indices.shape)) - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( + r1 = jax.grad(sum_op(bm.sparse.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) r2 = jax.grad(sum_op(taichi_csr_matvec))( data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' - r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) - r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))( + r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, 
transpose=transpose) diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 123ca657e..ec3ea3c5a 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -7,7 +7,6 @@ import brainpy as bp import brainpy.math as bm -from .._csr_mv import csrmv_brainpylib as brainpylib_csr_matvec # bm.set_platform('gpu') @@ -80,7 +79,10 @@ def test_homo(self, transpose, shape, homo_data): vector = rng.random(shape[0] if transpose else shape[1]) vector = bm.as_jax(vector) - r1 = brainpylib_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + heter_data = bm.ones(indices.shape).value * homo_data + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r1 = (vector @ dense) if transpose else (dense @ vector) r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -106,12 +108,11 @@ def test_homo_vmap(self, transpose, shape, v): homo_data = bm.ones(10).value * v dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - f1 = partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) + f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) r1 = jax.vmap(f1)(homo_data) - r2 = jax.vmap(f1)(homo_data) + r2 = jax.vmap(f2)(homo_data) self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() @@ -138,8 +139,11 @@ def test_homo_grad(self, transpose, shape, homo_data): # print('grad data start') # grad 'data' - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( - homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum() + if transpose else + ((dense * a) @ vector).sum()), + 
argnums=0) + r1 = dense_f1(homo_data) r2 = jax.grad(sum_op(taichi_csr_matvec))( homo_data, indices, indptr, vector, shape=shape, transpose=transpose) @@ -155,15 +159,19 @@ def test_homo_grad(self, transpose, shape, homo_data): # print('grad vector start') # grad 'vector' - r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + dense_data = dense * homo_data + dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) + r3 = dense_f2(vector) r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) - r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() + if transpose else + ((dense * a) @ v).sum()), + argnums=(0, 1)) + r5 = dense_f3(homo_data, vector) r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) @@ -190,7 +198,8 @@ def test_heter(self, transpose, shape): vector = rng.random(shape[0] if transpose else shape[1]) vector = bm.as_jax(vector) - r1 = brainpylib_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r1 = (vector @ dense) if transpose else (dense @ vector) r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape) self.assertTrue(compare_with_nan_tolerance(r1, r2)) @@ -216,8 +225,7 @@ def test_heter_vmap(self, transpose, shape): dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - f1 = partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, 
transpose=transpose) + f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) r1 = jax.vmap(f1)(heter_data) @@ -242,21 +250,26 @@ def test_heter_grad(self, transpose, shape): vector = bm.as_jax(vector) # grad 'data' - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( - heter_data, indices, indptr, vector, shape=shape, transpose=transpose) + dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), + argnums=0) + r1 = dense_f1(dense_data) r2 = jax.grad(sum_op(taichi_csr_matvec))( heter_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'vector' - r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), + argnums=0) + r3 = dense_f2(vector) r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) - r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() + if transpose else + ((dense * a) @ v).sum()), + argnums=(0, 1)) + r5 = dense_f3(heter_data, vector) r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) From c0d756154b593948439829c1c071717c83085605 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 29 Jan 2024 13:43:17 +0800 Subject: [PATCH 20/27] Update test_csrmv.py --- brainpy/_src/math/sparse/tests/test_csrmv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index ec3ea3c5a..042111fb6 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -265,9 +265,9 @@ def test_heter_grad(self, transpose, shape): heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) - dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() + dense_f3 = jax.grad(lambda a, v: ((v @ (dense_data * a)).sum() if transpose else - ((dense * a) @ v).sum()), + ((dense_data * a) @ v).sum()), argnums=(0, 1)) r5 = dense_f3(heter_data, vector) r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( From b49ae90c1740e0fc853899ed7f89e45612b9b59f Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 29 Jan 2024 13:46:06 +0800 Subject: [PATCH 21/27] Update test_matvec.py --- brainpy/_src/math/jitconn/tests/test_matvec.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index c857e2e2e..af91c2fff 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -4,9 +4,6 @@ import jax import jax.numpy as jnp from absl.testing import parameterized -from .._matvec import (mv_prob_homo_brainpylib as brainpylib_mv_prob_homo, - mv_prob_uniform_brainpylib as brainpylib_mv_prob_uniform, - mv_prob_normal_brainpylib as brainpylib_mv_prob_normal,) import brainpy.math as bm From 21b8426aaca36f94f88453406cec0313ec772664 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 29 Jan 2024 14:49:21 +0800 Subject: [PATCH 22/27] Update test_event_matvec.py --- brainpy/_src/math/jitconn/tests/test_event_matvec.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py 
b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 62b665f47..0ef8947bf 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -10,12 +10,9 @@ shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] -brainpylib_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='brainpylib') -taichi_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='taichi') -brainpylib_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='brainpylib') -taichi_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='taichi') -brainpylib_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='brainpylib') -taichi_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='taichi') +taichi_mv_prob_homo = bm.jitconn.event_mv_prob_homo +taichi_mv_prob_uniform = bm.jitconn.event_mv_prob_uniform +taichi_mv_prob_normal = bm.jitconn.event_mv_prob_normal class Test_event_matvec_prob_conn(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): From c43bead1174908c69ca1b204906d72c251916c08 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 29 Jan 2024 15:13:58 +0800 Subject: [PATCH 23/27] Update test_event_csrmv.py --- brainpy/_src/math/event/tests/test_event_csrmv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 0598734a7..19f7d8685 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -169,7 +169,7 @@ def test_heter_vmap(self, shape, transpose): # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, events=events, + f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, 
indptr=indptr, vector=events, shape=shape, transpose=transpose)) f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) From efc8923dce1211ed0218f1ccc77c057e370a61af Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 29 Jan 2024 18:10:17 +0800 Subject: [PATCH 24/27] [math] Update pytests --- .../_src/math/event/tests/test_event_csrmv.py | 2 +- brainpy/_src/math/sparse/tests/test_csrmv.py | 43 ++++++++----------- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 19f7d8685..e0f38490f 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -67,7 +67,7 @@ def test_homo_vmap(self, shape, transpose, homo_data): # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, events=events, + f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, shape=shape, transpose=transpose)) f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 042111fb6..2c75f0901 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -111,7 +111,7 @@ def test_homo_vmap(self, transpose, shape, v): f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) - r1 = jax.vmap(f1)(homo_data) + r1 = jax.vmap(f1)(dense_data) r2 = jax.vmap(f2)(homo_data) self.assertTrue(bm.allclose(r1, r2)) @@ -147,14 +147,6 @@ def test_homo_grad(self, transpose, shape, homo_data): 
r2 = jax.grad(sum_op(taichi_csr_matvec))( homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - # csr_f1 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, - # shape=shape, transpose=transpose).sum(), - # argnums=0) - # csr_f2 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, - # shape=shape, transpose=transpose)[0].sum(), - # argnums=0) - # r1 = csr_f1(homo_data) - # r2 = csr_f2(homo_data) self.assertTrue(bm.allclose(r1, r2)) # print('grad vector start') @@ -200,7 +192,7 @@ def test_heter(self, transpose, shape): dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (vector @ dense) if transpose else (dense @ vector) - r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(compare_with_nan_tolerance(r1, r2)) @@ -228,7 +220,7 @@ def test_heter_vmap(self, transpose, shape): f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) - r1 = jax.vmap(f1)(heter_data) + r1 = jax.vmap(f1)(dense_data) r2 = jax.vmap(f2)(heter_data) self.assertTrue(compare_with_nan_tolerance(r1, r2)) @@ -252,27 +244,26 @@ def test_heter_grad(self, transpose, shape): # grad 'data' dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), argnums=0) - r1 = dense_f1(dense_data) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - heter_data, indices, indptr, vector, shape=shape, transpose=transpose) + csr_f1 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, + shape=shape, + transpose=transpose).sum(), + argnums=0) + r1 = csr_f1(heter_data) + r2 = dense_f1(dense_data) + rows, cols = bm.sparse.csr_to_coo(indices, indptr) + r2 = r2[rows, cols] + print(r1.shape, r2.shape) self.assertTrue(bm.allclose(r1, r2)) # grad 'vector' dense_f2 = 
jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), argnums=0) + csr_f2 = jax.grad(lambda v: taichi_csr_matvec(heter_data, indices, indptr, v, + shape=shape, + transpose=transpose).sum(), + argnums=0) r3 = dense_f2(vector) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r4 = csr_f2(vector) self.assertTrue(bm.allclose(r3, r4)) - dense_f3 = jax.grad(lambda a, v: ((v @ (dense_data * a)).sum() - if transpose else - ((dense_data * a) @ v).sum()), - argnums=(0, 1)) - r5 = dense_f3(heter_data, vector) - r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - bm.clear_buffer_memory() From c02781743a71229bfb8e7789e679f99ca4351d4c Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 29 Jan 2024 20:11:15 +0800 Subject: [PATCH 25/27] [math] Fix test case bugs --- .../math/jitconn/tests/test_event_matvec.py | 31 - .../_src/math/jitconn/tests/test_matvec.py | 992 +++++++++--------- 2 files changed, 479 insertions(+), 544 deletions(-) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 0ef8947bf..8f8ee9bee 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -66,16 +66,6 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - r3 = jax.block_until_ready(r3) - self.assertTrue(jnp.allclose(r1, r3)) - # indices, indptr = 
bp.conn.FixedProb(prob)(*shape).require('pre2post') # indices = bm.as_jax(indices) # indptr = bm.as_jax(indptr) @@ -245,16 +235,6 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - r3 = jax.block_until_ready(r3) - self.assertTrue(jnp.allclose(r1, r3)) if x64: bm.disable_x64() bm.clear_buffer_memory() @@ -427,17 +407,6 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - r3 = jax.block_until_ready(r3) - self.assertTrue(jnp.allclose(r1, r3)) - if x64: bm.disable_x64() bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index af91c2fff..0b091d72f 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -14,522 +14,488 @@ taichi_mv_prob_uniform = bm.jitconn.mv_prob_uniform taichi_mv_prob_normal = bm.jitconn.mv_prob_normal + class Test_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - 
homo_data=homo_data, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for homo_data in [-1., 1.] - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=None, x64=False): - print(f'test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = taichi_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - self.assertTrue(jnp.allclose(r1, r2)) - - r2 = taichi_mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - self.assertTrue(jnp.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, 
shape[0] if transpose else shape[1]))) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( - event, data, - conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - )[0] + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.named_parameters( + dict(testcase_name=(f'test_homo, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}, ' + f'x64 = {x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + homo_data=homo_data, + seed=1234) + for x64 in [True, False] + for transpose in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + for homo_data in [-1., 1.] ) - r1 = f1(events, weights) - r2 = f1(events, weights) - self.assertTrue(jnp.allclose(r1, r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 - events = events.astype(float) - - f1 = jax.grad( - lambda 
event, data: taichi_mv_prob_homo( - event, data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum(), - argnums=0 + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=None, x64=False): + print(f'test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = taichi_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = taichi_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + self.assertTrue(jnp.allclose(r1, r2)) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=(f'test_homo_vmap, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + x64=x64) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) 
- - self.assertTrue(jnp.allclose(r1 * 2., r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - x64=x64, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=None, x64=False): - print(f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64 = {x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - r2 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform_vmap, shape = 
{shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - - r1 = f1(events) - r2 = f1(events) - self.assertTrue(jnp.allclose(r1, r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - f1 = jax.grad( - lambda e, w_low, w_high: taichi_mv_prob_uniform( - e, - 
w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + print(f'test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: taichi_mv_prob_homo( + event, data, + conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose + )[0] + ) + r1 = f1(events, weights) + r2 = f1(events, weights) + self.assertTrue(jnp.allclose(r1, r2)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=(f'test_homo_grad, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + x64=x64) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] ) - - r1 = f1(events, 0., 1.) - r2 = f1(events, 0., 2.) 
- - self.assertTrue(bm.allclose(r1 * 2., r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=(f'test_normal, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma},' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - seed=1234 + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: taichi_mv_prob_homo( + event, data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum(), + argnums=0 + ) + r1 = f1(events, 1.) + r2 = f1(events, 2.) 
+ + self.assertTrue(jnp.allclose(r1 * 2., r2)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=(f'test_uniform, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}' + f'x64 = {x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + w_low=w_low, + w_high=w_high, + x64=x64, + seed=1234) + for x64 in [True, False] + for transpose in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=None, x64=False): - print(f'_test_normal: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - r2 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - c = jnp.allclose(r1, r2) - if not c: - 
print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_normal_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - print(r1 - r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64, - testcase_name=f'test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = 
{outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=None, x64=False): + print(f'test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64 = {x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + c = jnp.allclose(r1, r2) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=f'test_uniform_vmap, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}', + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + x64=x64) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) 
- self.assertTrue(bm.allclose(r1 * 2., r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + print(f'test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) + + r1 = f1(events) + r2 = f1(events) + self.assertTrue(jnp.allclose(r1, r2)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=(f'test_uniform_grad, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'x64={x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + x64=x64) + for x64 in [True, False] + for transpose in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + f1 = jax.grad( + lambda e, w_low, w_high: taichi_mv_prob_uniform( + e, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() + ) + + r1 = f1(events, 0., 1.) + r2 = f1(events, 0., 2.) 
+ + self.assertTrue(bm.allclose(r1 * 2., r2)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict( + testcase_name=(f'test_normal, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_mu = {w_mu}, ' + f'w_sigma = {w_sigma},' + f'x64={x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + w_mu=w_mu, + w_sigma=w_sigma, + seed=1234 + ) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=None, x64=False): + print(f'_test_normal: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_mu = {w_mu}, ' + f'w_sigma = {w_sigma}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + c = jnp.allclose(r1, r2) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=f'test_normal_vmap, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'x64={x64}', + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in 
[True, False] + for shape in shapes + for prob in [0.01, 0.1] + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) + r1 = f1(events) + r2 = f1(events) + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + print(r1 - r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + x64=x64, + testcase_name=f'test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + ) + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_sigma: taichi_mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() + ) + r1 = f1(events, 1.) + r2 = f1(events, 2.) 
+ self.assertTrue(bm.allclose(r1 * 2., r2)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() From 12d045d97979f74b2f1ad75e51d91408debb13c9 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 29 Jan 2024 22:07:07 +0800 Subject: [PATCH 26/27] [math] Add more tolerance for jitconn operators --- .../math/jitconn/tests/test_event_matvec.py | 995 +++++++++--------- .../_src/math/jitconn/tests/test_matvec.py | 34 +- 2 files changed, 515 insertions(+), 514 deletions(-) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 8f8ee9bee..b10d55d21 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -14,506 +14,507 @@ taichi_mv_prob_uniform = bm.jitconn.event_mv_prob_uniform taichi_mv_prob_normal = bm.jitconn.event_mv_prob_normal + class Test_event_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - homo_data=[-1., ], - bool_event=[True, False], - seed=[1234], - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=None, x64=False): - print(f'_test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - 
outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - - # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') - # indices = bm.as_jax(indices) - # indptr = bm.as_jax(indptr) - # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, - # shape=shape, transpose=transpose) - # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - bool_event=[True, False], - seed=[1234], - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=None, x64=False): - print(f'_test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - transpose=transpose, outdim_parallel=outdim_parallel - )[0] + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.01, 0.1, 0.5], + homo_data=[-1., ], + bool_event=[True, False], + seed=[1234], ) - r1 
= f1(events, weights) - r1 = jax.block_until_ready(r1) - r2 = f1(events, weights) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.5] - ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.5 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), - argnums=0 + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): + print(f'_test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + + r1 = taichi_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = 
jax.block_until_ready(r1) + + r2 = taichi_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') + # indices = bm.as_jax(indices) + # indptr = bm.as_jax(indptr) + # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, + # shape=shape, transpose=transpose) + # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.01, 0.1, 0.5], + bool_event=[True, False], + seed=[1234], ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - - r3 = f1(events, 3.) - r3 = jax.block_until_ready(r3) - - self.assertTrue(jnp.allclose(r1 * 3., r3)) - self.assertTrue(jnp.allclose(r1 * 2., r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - bool_event=bool_event, - seed=1234, - x64=x64 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.4] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - for bool_event in [True, False] - ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, - bool_event=True, seed=None, x64=False): - 
print(f'_test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_uniform_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'bool_event={bool_event}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=None, x64=False): - print(f'_test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap( - lambda e: 
taichi_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): + print(f'_test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: taichi_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + transpose=transpose, outdim_parallel=outdim_parallel + )[0] + ) + r1 = f1(events, weights) + r1 = jax.block_until_ready(r1) + r2 = f1(events, weights) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}', + shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, seed=1234, + x64=x64) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1, 0.5] ) - - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - testcase_name=f'_test_uniform_grad: ' - f'shape = {shape}, ' - 
f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_high: taichi_mv_prob_uniform( - e, - w_low=0., - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.5 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: taichi_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), + argnums=0 + ) + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + + r3 = f1(events, 3.) 
+ r3 = jax.block_until_ready(r3) + + self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=f'test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}', + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + w_low=w_low, + w_high=w_high, + bool_event=bool_event, + seed=1234, + x64=x64 + ) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1, 0.4] + for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] + for bool_event in [True, False] ) - - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2., r2)) - # print(r1) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu={w_mu}, ' - f'w_sigma={w_sigma}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, ] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - for bool_event in [True, False] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, - bool_event=True, seed=None, x64=False): - print(f'_test_normal: shape = {shape}, ' - f'transpose 
= {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' - f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=None, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - 
seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.jit( - jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = taichi_mv_prob_uniform(events, + w_low=w_low, + 
w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel, prob=prob, + bool_event=bool_event, + x64=x64, + seed=1234, + testcase_name=f'_test_uniform_vmap: ' + f'shape={shape}, ' + f'transpose={transpose}, ' + f'bool_event={bool_event}, ' + f'outdim_parallel={outdim_parallel}, ' + f'prob={prob}, ' + f'x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + for bool_event in [True, False] + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap( + lambda e: taichi_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + ) + + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, + transpose=transpose, + 
outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + testcase_name=f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_high: taichi_mv_prob_uniform( + e, + w_low=0., + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() + ) + + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) 
+ r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + # print(r1) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + w_mu=w_mu, + w_sigma=w_sigma, + bool_event=bool_event, + x64=x64, + seed=1234, + testcase_name=f'_test_normal: ' + f'shape={shape}, ' + f'transpose={transpose}, ' + f'outdim_parallel={outdim_parallel}, ' + f'prob={prob}, ' + f'w_mu={w_mu}, ' + f'w_sigma={w_sigma}, ' + f'bool_event={bool_event}, ' + f'x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1, ] + for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] + for bool_event in [True, False] + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal: shape = {shape}, ' + f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' + f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + 
prob=prob, + bool_event=bool_event, + x64=x64, + seed=1234, + testcase_name=f'_test_normal_vmap: ' + f'shape={shape}, ' + f'transpose={transpose}, ' + f'outdim_parallel={outdim_parallel}, ' + f'prob={prob}, ' + f'bool_event={bool_event}, ' + f'x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + for bool_event in [True, False] + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + x64=x64, + seed=1234, + testcase_name=f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) 
- r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.jit( + jax.grad( + lambda e, w_sigma: taichi_mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() + ) + ) + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 0b091d72f..2e6e406cf 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -41,7 +41,7 @@ def __init__(self, *args, platform='cpu', **kwargs): for prob in [0.01, 0.1] for homo_data in [-1., 1.] 
) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=None, x64=False): + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): print(f'test_homo: ' f'shape = {shape}, ' f'transpose = {transpose}, ' @@ -70,7 +70,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - self.assertTrue(jnp.allclose(r1, r2)) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) bm.clear_buffer_memory() @@ -91,7 +91,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non for shape in shapes for prob in [0.01, 0.1] ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): print(f'test_homo_vmap: ' f'shape = {shape}, ' f'transpose = {transpose}, ' @@ -114,7 +114,7 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64 ) r1 = f1(events, weights) r2 = f1(events, weights) - self.assertTrue(jnp.allclose(r1, r2)) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) if x64: bm.disable_x64() @@ -137,7 +137,7 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64 for shape in shapes for prob in [0.01, 0.1] ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): print(f'_test_homo_grad: ' f'shape = {shape}, ' f'transpose = {transpose}, ' @@ -164,7 +164,7 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 r1 = f1(events, 1.) r2 = f1(events, 2.) 
- self.assertTrue(jnp.allclose(r1 * 2., r2)) + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) if x64: bm.disable_x64() @@ -193,7 +193,7 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 for prob in [0.01, 0.1] for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=None, x64=False): + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): print(f'test_uniform: ' f'shape = {shape}, ' f'transpose = {transpose}, ' @@ -225,7 +225,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2) + c = jnp.allclose(r1, r2, atol=1e-6) if not c: print(r1, r2) self.assertTrue(c) @@ -251,7 +251,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s for shape in shapes for prob in [0.01, 0.1] ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): print(f'test_uniform_vmap: ' f'shape = {shape}, ' f'transpose = {transpose}, ' @@ -274,7 +274,7 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, r1 = f1(events) r2 = f1(events) - self.assertTrue(jnp.allclose(r1, r2)) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) if x64: bm.disable_x64() @@ -298,7 +298,7 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, for shape in shapes for prob in [0.01, 0.1] ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): print(f'_test_uniform_grad: ' f'shape = {shape}, ' f'transpose = {transpose}, ' @@ -327,7 +327,7 @@ def test_uniform_grad(self, shape, 
transpose, outdim_parallel, prob, seed=None, r1 = f1(events, 0., 1.) r2 = f1(events, 0., 2.) - self.assertTrue(bm.allclose(r1 * 2., r2)) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) if x64: bm.disable_x64() @@ -357,7 +357,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, for prob in [0.01, 0.1] for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=None, x64=False): + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): print(f'_test_normal: ' f'shape = {shape}, ' f'transpose = {transpose}, ' @@ -389,7 +389,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2) + c = jnp.allclose(r1, r2, atol=1e-6) if not c: print(r1, r2) self.assertTrue(c) @@ -415,7 +415,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se for shape in shapes for prob in [0.01, 0.1] ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): print(f'_test_normal_vmap: ' f'shape = {shape}, ' f'transpose = {transpose}, ' @@ -467,7 +467,7 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x for shape in shapes for prob in [0.01, 0.1] ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): print(f'_test_normal_grad: ' f'shape = {shape}, ' f'transpose = {transpose}, ' @@ -494,7 +494,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x ) r1 = f1(events, 1.) r2 = f1(events, 2.) 
- self.assertTrue(bm.allclose(r1 * 2., r2)) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) if x64: bm.disable_x64() From e1f40054a51ec6a5b6c5faa4f3f3ddd11030786f Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Mon, 29 Jan 2024 23:14:03 +0800 Subject: [PATCH 27/27] format the code --- brainpy/_src/math/event/_csr_matvec.py | 23 ++++++++++-------- brainpy/_src/math/jitconn/_event_matvec.py | 28 +++++++++++++++------- brainpy/_src/math/jitconn/_matvec.py | 15 +++++++----- brainpy/_src/math/sparse/_csr_mv.py | 11 +++++---- 4 files changed, 48 insertions(+), 29 deletions(-) diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index c9d3d80ba..2e7895334 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -10,11 +10,9 @@ """ - from functools import partial from typing import Union, Tuple -import brainpy.math as bm import jax import jax.numpy as jnp import numba @@ -23,6 +21,7 @@ from jax.interpreters import ad, xla from jax.lib import xla_client +from brainpy._src.dependency_check import (import_brainpylib_gpu_ops) from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, @@ -31,7 +30,6 @@ from brainpy._src.math.sparse._csr_mv import csrmv_brainpylib as normal_csrmv from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi from brainpy._src.math.sparse._utils import csr_to_coo -from brainpy._src.dependency_check import (import_brainpylib_gpu_ops) from brainpy.errors import GPUOperatorNotFound __all__ = [ @@ -159,6 +157,7 @@ def csrmv_brainpylib( # computing return event_csr_matvec_p.bind(data, indices, indptr, events, shape=shape, transpose=transpose) + # ---------------------------------------------------------- # event csr matvec # ---------------------------------------------------------- @@ -600,9 +599,12 @@ def 
_event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events, event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p)) xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation -ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None, _event_csr_matvec_jvp_events_brainpylib) +ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None, + _event_csr_matvec_jvp_events_brainpylib) ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib register_general_batching(event_csr_matvec_p) + + # batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule @@ -688,6 +690,7 @@ def csrmv_taichi( return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] + # ------------- # CPU operators # ------------- @@ -958,7 +961,7 @@ def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), if events[indices[j]]: r += values[j] j += 32 - out[row_i] += r # TODO: warp-level primitive + out[row_i] += r # TODO: warp-level primitive @ti.kernel @@ -977,7 +980,8 @@ def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), if events[indices[j]] != 0.: r += values[j] j += 32 - out[row_i] += r # TODO: warp-level primitive + out[row_i] += r # TODO: warp-level primitive + def raw_csrmv_taichi( data: Union[float, jax.Array], @@ -1020,6 +1024,7 @@ def raw_csrmv_taichi( transpose=transpose, shape=shape) + def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) @@ -1047,6 +1052,8 @@ def _event_csr_matvec_transpose_taichi( row, col = csr_to_coo(indices, indptr) ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] 
return ct_values, indices, indptr, events + + def _define_op(cpu_kernel, gpu_kernel): prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) @@ -1080,7 +1087,3 @@ def _define_op(cpu_kernel, gpu_kernel): # not transpose heter _event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu) - - - - diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 21b082fc5..7971b4a92 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -3,7 +3,6 @@ from functools import partial from typing import Tuple, Optional -import brainpy.math as bm import jax import numpy as np from jax import numpy as jnp, dtypes @@ -19,12 +18,12 @@ mv_prob_homo, mv_prob_uniform, mv_prob_normal, - _general_checking, - raw_mv_prob_homo, - raw_mv_prob_uniform, + _general_checking, + raw_mv_prob_homo, + raw_mv_prob_uniform, raw_mv_prob_normal, - _mv_prob_homo_transpose, - _mv_prob_uniform_transpose, + _mv_prob_homo_transpose, + _mv_prob_uniform_transpose, _mv_prob_normal_transpose, _reverse) from brainpy._src.math.ndarray import _get_dtype @@ -51,7 +50,9 @@ def event_mv_prob_homo( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ @@ -67,7 +68,9 @@ def event_mv_prob_uniform( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return event_mv_prob_uniform_taichi(events, 
w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ @@ -83,7 +86,9 @@ def event_mv_prob_normal( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return event_mv_prob_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + ### BRAINPYLIB ### @@ -180,6 +185,7 @@ def event_mv_prob_normal_brainpylib( transpose=transpose, outdim_parallel=outdim_parallel)[0] + event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ @@ -872,6 +878,7 @@ def event_mv_prob_uniform_taichi( return raw_event_mv_prob_uniform(events, w_low, w_high, conn_len, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + def event_mv_prob_normal_taichi( events: jax.Array, w_mu: float, @@ -947,6 +954,7 @@ def event_mv_prob_normal_taichi( return raw_event_mv_prob_normal(events, w_mu, w_sigma, conn_len, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + # ------------- # CPU function # ------------- @@ -1075,9 +1083,11 @@ def _event_mv_prob_homo_outdim_parallel_bool_gpu( i_col += inc out[i_row] += r # TODO: warp-level reduction + def _reverse(shape): return shape[::-1] + # ------------- # CPU function # ------------- diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index f20c12f02..e33a0ab1e 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -4,7 +4,6 @@ from functools import partial from typing import Tuple, Optional, Union -import brainpy.math as bm import jax import numpy as np from jax import numpy as jnp, dtypes @@ -86,8 +85,8 @@ def mv_prob_homo( out: Array, ndarray The output of :math:`y = M @ v`.
""" - return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - + return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) def mv_prob_uniform( @@ -151,7 +150,8 @@ def mv_prob_uniform( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) def mv_prob_normal( @@ -215,7 +215,8 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) ### BRAINYPLIB ### @@ -456,7 +457,6 @@ def mv_prob_normal_brainpylib( outdim_parallel=outdim_parallel)[0] - def _matvec_prob_homo_abstract( vector, weight, clen, seed, *, shape, transpose, outdim_parallel ): @@ -1095,6 +1095,7 @@ def mv_prob_homo_taichi( return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + def mv_prob_uniform_taichi( vector: jax.Array, w_low: float, @@ -1170,6 +1171,7 @@ def mv_prob_uniform_taichi( return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + def mv_prob_normal_taichi( vector: jax.Array, w_mu: float, @@ -1245,6 +1247,7 @@ def mv_prob_normal_taichi( return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + def _reverse(shape): return shape[::-1] diff --git
a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index acd68b999..47704af04 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -81,8 +81,9 @@ def csrmv( else: return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method=method) + ### BRAINPYLIB ### - + def csrmv_brainpylib( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], @@ -164,6 +165,7 @@ def csrmv_brainpylib( else: raise ValueError(f'Only support methods: cusparse, scalar, vector, and adaptive. But we got {method}.') + def _csrmv_abstract(data, indices, indptr, vector, *, shape, transpose): if data.dtype not in [jnp.float32, jnp.float64]: raise TypeError(f'Only support float32 and float64. But we got {data.dtype}.') @@ -587,7 +589,7 @@ def csrmv_taichi( # if the shape of indices is (0,), then we return a zero vector if indices.shape[0] == 0: return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) - + return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] @@ -755,6 +757,7 @@ def _sparse_csr_matvec_transpose( return ct_data, indices, indptr, vector + def raw_csrmv_taichi( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], @@ -783,7 +786,7 @@ def raw_csrmv_taichi( outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)], transpose=transpose, shape=shape) - + def _define_op(cpu_kernel, gpu_kernel): prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) @@ -806,4 +809,4 @@ def _define_op(cpu_kernel, gpu_kernel): # no transpose heter _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) \ No newline at end of file + gpu_kernel=_sparse_csr_matvec_heter_gpu)