diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 09bf2958..b837dd92 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -8,6 +8,7 @@ import jax.numpy as jnp import numba import numpy as np +import taichi as ti from brainpy import math as bm from brainpy._src import connect, initialize as init @@ -29,6 +30,9 @@ 'CSRLinear', 'EventCSRLinear', 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear', 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear', + 'CSRLinear_taichi', 'EventCSRLinear_taichi', + 'JitFPHomoLinear_taichi', 'JitFPUniformLinear_taichi', 'JitFPNormalLinear_taichi', + 'EventJitFPHomoLinear_taichi', 'EventJitFPNormalLinear_taichi', 'EventJitFPUniformLinear_taichi', ] @@ -660,9 +664,11 @@ def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) + csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update) + def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): if w_min is None: w_min = -np.inf @@ -1278,3 +1284,705 @@ def _batch_mv(self, x): shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) + +### TAICHI CUSTOMIZED OPERATOR IMPLEMENTATION ### + +class _CSRLayer_taichi(Layer, SupportSTDP): + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + assert sharding is None, 'Currently this model does not support sharding.' + self.conn = conn + self.sharding = sharding + self.transpose = transpose + + # connection + self.indices, self.indptr = self.conn.require('csr') + + # weight + weight = init.parameter(weight, (self.indices.size,)) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if bm.isscalar(self.weight): + raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.') + if self.weight.shape != self.indices.shape: + raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: # update on presynaptic spike + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = csr_on_pre_update_taichi(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max) + if on_post is not None: # update on postsynaptic spike + if not hasattr(self, '_pre_ids'): + with jax.ensure_compile_time_eval(): + self._pre_ids, self._post_indptr, self.w_indices = csr2csc( + [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size) + ) + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = csc_on_post_update_taichi(self.weight.value, self._pre_ids, self._post_indptr, + self.w_indices, spike, trace, w_min, w_max) + + +class CSRLinear_taichi(_CSRLayer_taichi): + r"""Synaptic matrix multiplication with CSR sparse computation(taichi customized operator). + + It performs the computation of: + + .. 
math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a CSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) + + def update(self, x): + if x.ndim == 1: + return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose)[0] + elif x.ndim > 1: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_csrmv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_csrmv(self, x): + return bm.sparse.csrmv_taichi(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose)[0] + +class EventCSRLinear_taichi(_CSRLayer_taichi): + r"""Synaptic matrix multiplication with event CSR sparse computation(taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weight using a CSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
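+
+  Examples:
+
+    A minimal usage sketch; the connection, sizes, and spike encoding below
+    simply mirror the unit tests shipped with this patch and are illustrative::
+
+      import brainpy as bp
+      import brainpy.math as bm
+
+      layer = bp.dnn.EventCSRLinear_taichi(bp.conn.FixedProb(0.1, pre=100, post=100),
+                                           weight=bp.init.Normal())
+      spikes = bm.random.random((100,)) < 0.1   # presynaptic events
+      y = layer(spikes)                         # postsynaptic output, shape (100,)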
+  """
+
+  def __init__(
+      self,
+      conn: connect.TwoEndConnector,
+      weight: Union[float, ArrayType, Callable],
+      sharding: Optional[Sharding] = None,
+      mode: Optional[bm.Mode] = None,
+      name: Optional[str] = None,
+      transpose: bool = True,
+  ):
+    super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
+
+  def update(self, x):
+    if x.ndim == 1:
+      return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x,
+                                   shape=(self.conn.pre_num, self.conn.post_num),
+                                   transpose=self.transpose)[0]
+    elif x.ndim > 1:
+      shapes = x.shape[:-1]
+      x = bm.flatten(x, end_dim=-2)
+      y = jax.vmap(self._batch_csrmv)(x)
+      return bm.reshape(y, shapes + (y.shape[-1],))
+    else:
+      raise ValueError
+
+  def _batch_csrmv(self, x):
+    return bm.event.csrmv_taichi(self.weight, self.indices, self.indptr, x,
+                                 shape=(self.conn.pre_num, self.conn.post_num),
+                                 transpose=self.transpose)[0]
+
+
+@ti.kernel
+def _cpu_csr_on_pre_update_taichi(w: ti.types.ndarray(ndim=1),
+                                  indices: ti.types.ndarray(ndim=1),
+                                  indptr: ti.types.ndarray(ndim=1),
+                                  spike: ti.types.ndarray(ndim=1),
+                                  trace: ti.types.ndarray(ndim=1),
+                                  w_min: ti.types.ndarray(ndim=1),
+                                  w_max: ti.types.ndarray(ndim=1),
+                                  out_w: ti.types.ndarray(ndim=1)):
+  w_min_value = w_min[0]
+  w_max_value = w_max[0]
+  # start from an element-wise copy of the current weights
+  for k in range(out_w.shape[0]):
+    out_w[k] = w[k]
+  # clamp-update the synapses of every presynaptic neuron that spiked
+  for i in range(spike.shape[0]):  # pre id
+    if spike[i]:
+      for k in range(indptr[i], indptr[i + 1]):  # synapse id
+        j = indices[k]  # post id
+        out_w[k] = ti.min(ti.max(out_w[k] + trace[j], w_min_value), w_max_value)
+
+
+csr_on_pre_update_prim_taichi = bm.XLACustomOp(_cpu_csr_on_pre_update_taichi)
+
+
+def csr_on_pre_update_taichi(w, indices, indptr, spike, trace, w_min=None, w_max=None):
+  if w_min is None:
+    w_min = -np.inf
+  if w_max is None:
+    w_max = np.inf
+  # the kernel reads w_min/w_max as 1-element arrays, so lift the scalars
+  w = bm.as_jax(w)
+  w_min = jnp.atleast_1d(w_min)
+  w_max = jnp.atleast_1d(w_max)
+  return csr_on_pre_update_prim_taichi(w, indices, indptr, spike, trace, w_min, w_max,
+                                       outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
+
+
+@ti.kernel
+def _cpu_csc_on_post_update_taichi(w: ti.types.ndarray(ndim=1),
+                                   post_ids: ti.types.ndarray(ndim=1),
+                                   indptr: ti.types.ndarray(ndim=1),
+                                   w_ids: ti.types.ndarray(ndim=1),
+                                   spike: ti.types.ndarray(ndim=1),
+                                   trace: ti.types.ndarray(ndim=1),
+                                   w_min: ti.types.ndarray(ndim=1),
+                                   w_max: ti.types.ndarray(ndim=1),
+                                   out_w: ti.types.ndarray(ndim=1)):
+  w_min_value = w_min[0]
+  w_max_value = w_max[0]
+  # start from an element-wise copy of the current weights
+  for k in range(out_w.shape[0]):
+    out_w[k] = w[k]
+  # clamp-update the synapses of every postsynaptic neuron that spiked
+  for i in range(spike.shape[0]):  # post id
+    if spike[i]:
+      for k in range(indptr[i], indptr[i + 1]):
+        j = post_ids[k]  # pre id
+        l = w_ids[k]  # syn id
+        out_w[l] = ti.min(ti.max(out_w[l] + trace[j], w_min_value), w_max_value)
+
+
+csc_on_post_update_prim_taichi = bm.XLACustomOp(_cpu_csc_on_post_update_taichi)
+
+
+def csc_on_post_update_taichi(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None):
+  if w_min is None:
+    w_min = -np.inf
+  if w_max is None:
+    w_max = np.inf
+  # the kernel reads w_min/w_max as 1-element arrays, so lift the scalars
+  w = bm.as_jax(w)
+  w_min = jnp.atleast_1d(w_min)
+  w_max = jnp.atleast_1d(w_max)
+  return csc_on_post_update_prim_taichi(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max,
+                                        outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
+
+
+class JitFPHomoLinear_taichi(Layer):
+  r"""Synaptic matrix multiplication with just-in-time connectivity (taichi customized operator).
+
+  It performs the computation of:
+
+  .. math::
+
+     y = x @ M
+
+  where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
+  :math:`M` the synaptic weights which have a fixed sparse connectivity and weights.
+ Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is the same :math:`weight`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + weight: float. The synaptic value at each position. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + weight: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class JitFPUniformLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_low: float. The lowest value of the uniform distribution. + w_high: float. The highest value of the uniform distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
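+
+  Examples:
+
+    A minimal usage sketch; the argument values mirror the unit tests shipped
+    with this patch and are illustrative::
+
+      import brainpy as bp
+      import brainpy.math as bm
+
+      layer = bp.dnn.JitFPUniformLinear_taichi(100, 200, prob=0.05,
+                                               w_low=-0.01, w_high=0.01, seed=123)
+      y = layer(bm.random.random((100,)))   # shape (200,)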
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_low: float, + w_high: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_low = w_low + self.w_high = w_high + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class JitFPNormalLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_mu: float. The center of the normal distribution. + w_sigma: float. The standard variance of the normal distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
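+
+  Examples:
+
+    A minimal usage sketch; the argument values mirror the unit tests shipped
+    with this patch and are illustrative::
+
+      import brainpy as bp
+      import brainpy.math as bm
+
+      layer = bp.dnn.JitFPNormalLinear_taichi(100, 200, prob=0.1,
+                                              w_mu=-0.01, w_sigma=0.01, seed=123)
+      y = layer(bm.random.random((100,)))   # shape (200,)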
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_mu: float, + w_sigma: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + transpose: bool = False, + atomic: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_mu = w_mu + self.w_sigma = w_sigma + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPHomoLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is the same :math:`weight`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + weight: float. The synaptic value at each position. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
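+
+  Examples:
+
+    A minimal usage sketch; the argument values and spike encoding mirror the
+    unit tests shipped with this patch and are illustrative::
+
+      import brainpy as bp
+      import brainpy.math as bm
+
+      layer = bp.dnn.EventJitFPHomoLinear_taichi(100, 200, prob=0.05, weight=0.01, seed=123)
+      spikes = bm.random.random((100,)) < 0.1
+      y = layer(spikes)   # shape (200,)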
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + weight: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 1000000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_homo_taichi(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPUniformLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_low: float. The lowest value of the uniform distribution. + w_high: float. The highest value of the uniform distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
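+
+  Examples:
+
+    A minimal usage sketch; the argument values and spike encoding mirror the
+    unit tests shipped with this patch and are illustrative::
+
+      import brainpy as bp
+      import brainpy.math as bm
+
+      layer = bp.dnn.EventJitFPUniformLinear_taichi(100, 200, prob=0.05,
+                                                    w_low=-0.01, w_high=0.01, seed=123)
+      spikes = bm.random.random((100,)) < 0.1
+      y = layer(spikes)   # shape (200,)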
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_low: float, + w_high: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_low = w_low + self.w_high = w_high + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_uniform_taichi(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPNormalLinear_taichi(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity(taichi customized operator). + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + w_mu: float. The center of the normal distribution. + w_sigma: float. The standard variance of the normal distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
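+
+  Examples:
+
+    A minimal usage sketch; the argument values and spike encoding mirror the
+    unit tests shipped with this patch and are illustrative::
+
+      import brainpy as bp
+      import brainpy.math as bm
+
+      layer = bp.dnn.EventJitFPNormalLinear_taichi(100, 200, prob=0.1,
+                                                   w_mu=-0.01, w_sigma=0.01, seed=123)
+      spikes = bm.random.random((100,)) < 0.1
+      y = layer(spikes)   # shape (200,)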
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_mu: float, + w_sigma: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + transpose: bool = False, + atomic: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_mu = w_mu + self.w_sigma = w_sigma + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_normal_taichi(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index da49bdbf..48fc60a0 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -213,6 +213,133 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): self.assertTrue(y2.shape == shape + (200,)) bm.clear_buffer_memory() + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_CSRLinear_taichi(self, conn): + bm.random.seed() + f = bp.dnn.CSRLinear_taichi(conn, weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + bm.clear_buffer_memory() + + + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_EventCSRLinear_taichi(self,conn): + bm.random.seed() + f=bp.layers.EventCSRLinear_taichi(conn,weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + bm.clear_buffer_memory() + + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPHomoLinear_taichi(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.JitFPHomoLinear_taichi(100, 200, prob, weight, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPUniformLinear_taichi(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.JitFPUniformLinear_taichi(100, 200, prob, w_low, w_high, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + 
prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPNormalLinear_taichi(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.JitFPNormalLinear_taichi(100, 200, prob, w_mu, w_sigma, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPHomoLinear_taichi(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.EventJitFPHomoLinear_taichi(100, 200, prob, weight, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPUniformLinear_taichi(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.EventJitFPUniformLinear_taichi(100, 200, prob, w_low, w_high, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPNormalLinear_taichi(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.EventJitFPNormalLinear_taichi(100, 200, prob, w_mu, w_sigma, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + if __name__ == '__main__': absltest.main() diff --git a/brainpy/dnn/linear.py b/brainpy/dnn/linear.py index 762c3c28..44a51b9d 100644 --- a/brainpy/dnn/linear.py +++ b/brainpy/dnn/linear.py @@ -14,4 +14,12 @@ EventJitFPHomoLinear as EventJitFPHomoLinear, EventJitFPNormalLinear as EventJitFPNormalLinear, EventJitFPUniformLinear as EventJitFPUniformLinear, + CSRLinear_taichi as CSRLinear_taichi, + EventCSRLinear_taichi as EventCSRLinear_taichi, + JitFPHomoLinear_taichi as JitFPHomoLinear_taichi, + JitFPUniformLinear_taichi as JitFPUniformLinear_taichi, + JitFPNormalLinear_taichi as JitFPNormalLinear_taichi, + EventJitFPHomoLinear_taichi as EventJitFPHomoLinear_taichi, + EventJitFPNormalLinear_taichi as EventJitFPNormalLinear_taichi, + EventJitFPUniformLinear_taichi as EventJitFPUniformLinear_taichi, )
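The re-exports above make the taichi-backed layers available directly from brainpy.dnn,
next to the existing classes. A short usage sketch of two of the new variants; the
argument values are taken from the unit tests in this patch and are purely illustrative:

    import brainpy as bp
    import brainpy.math as bm

    # CSR sparse layer built from an explicit connector
    csr = bp.dnn.CSRLinear_taichi(bp.conn.FixedProb(0.1, pre=100, post=100),
                                  weight=bp.init.Normal())
    y1 = csr(bm.random.random((100,)))       # shape (100,)

    # just-in-time connectivity with a homogeneous weight
    jit_homo = bp.dnn.JitFPHomoLinear_taichi(100, 200, prob=0.05, weight=0.01, seed=123)
    y2 = jit_homo(bm.random.random((100,)))  # shape (200,)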