From dd084a7a75fac9af5f3fcefd9e65d6cd77a668c1 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Mon, 27 May 2024 12:59:31 +0800
Subject: [PATCH 1/3] [math] Add get JIT connect matrix methods for
 `brainpy.dnn.linear`

---
 brainpy/_src/dnn/linear.py                    |  36 ++++++
 brainpy/_src/math/jitconn/matvec.py           | 111 ++++++++++++++++++
 .../jitconn/tests/test_get_connect_matrix.py  |  46 ++++++++
 brainpy/_src/math/op_register/base.py         |  13 +-
 brainpy/math/jitconn.py                       |   2 +
 5 files changed, 203 insertions(+), 5 deletions(-)
 create mode 100644 brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py

diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index c524fb0b..2570835f 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -1058,6 +1058,12 @@ def _batch_mv(self, x):
                               transpose=self.transpose,
                               outdim_parallel=not self.atomic)
 
+  def get_connect_matrix(self):
+    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
+                                         shape=(self.num_out, self.num_in),
+                                         transpose=self.transpose,
+                                         outdim_parallel=not self.atomic)
+
 
 class JitFPUniformLinear(Layer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
@@ -1138,6 +1144,12 @@ def _batch_mv(self, x):
                               transpose=self.transpose,
                               outdim_parallel=not self.atomic)
 
+  def get_connect_matrix(self):
+    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
+                                         shape=(self.num_out, self.num_in),
+                                         transpose=self.transpose,
+                                         outdim_parallel=not self.atomic)
+
 
 class JitFPNormalLinear(Layer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
@@ -1218,6 +1230,12 @@ def _batch_mv(self, x):
                               transpose=self.transpose,
                               outdim_parallel=not self.atomic)
 
+  def get_connect_matrix(self):
+    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
+                                         shape=(self.num_out, self.num_in),
+                                         transpose=self.transpose,
+                                         outdim_parallel=not self.atomic)
+
 
 class EventJitFPHomoLinear(Layer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
@@ -1297,6 +1315,12 @@ def _batch_mv(self, x):
                               transpose=self.transpose,
                               outdim_parallel=not self.atomic)
 
+  def get_connect_matrix(self):
+    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
+                                         shape=(self.num_out, self.num_in),
+                                         transpose=self.transpose,
+                                         outdim_parallel=not self.atomic)
+
 
 class EventJitFPUniformLinear(Layer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
@@ -1377,6 +1401,12 @@ def _batch_mv(self, x):
                               transpose=self.transpose,
                               outdim_parallel=not self.atomic)
 
+  def get_connect_matrix(self):
+    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
+                                         shape=(self.num_out, self.num_in),
+                                         transpose=self.transpose,
+                                         outdim_parallel=not self.atomic)
+
 
 class EventJitFPNormalLinear(Layer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
@@ -1456,3 +1486,9 @@ def _batch_mv(self, x):
                                     shape=(self.num_out, self.num_in),
                                     transpose=self.transpose,
                                     outdim_parallel=not self.atomic)
+
+  def get_connect_matrix(self):
+    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
+                                         shape=(self.num_out, self.num_in),
+                                         transpose=self.transpose,
+                                         outdim_parallel=not self.atomic)
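Note: all six JIT-connectivity layers gain the same `get_connect_matrix` method, which materializes the otherwise implicit random connectivity as a dense boolean mask; because the layer's fixed `seed` fully determines the random streams, the mask is exactly the connectivity the matvec kernels sample on the fly. A minimal usage sketch — the constructor signature `(num_in, num_out, prob, weight, seed=...)` is assumed from the existing layer API and is not part of this patch:

    import brainpy as bp

    # Assumed signature: JitFPHomoLinear(num_in, num_out, prob, weight, seed=...)
    layer = bp.dnn.JitFPHomoLinear(100, 200, prob=0.1, weight=1.5, seed=42)
    conn = layer.get_connect_matrix()
    # With the default transpose=False, conn has shape (num_out, num_in) = (200, 100).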
diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py
index 00e5778f..ad168133 100644
--- a/brainpy/_src/math/jitconn/matvec.py
+++ b/brainpy/_src/math/jitconn/matvec.py
@@ -20,6 +20,7 @@
   'mv_prob_homo',
   'mv_prob_uniform',
   'mv_prob_normal',
+  'get_connect_matrix',
 ]
 
 
@@ -257,6 +258,49 @@ def mv_prob_normal(
                             transpose=transpose, outdim_parallel=outdim_parallel)[0]
 
 
+def get_connect_matrix(
+    conn_prob: float,
+    seed: Optional[int] = None,
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+    outdim_parallel: bool = True,
+) -> jax.Array:
+  r"""Get the connection matrix :math:`M` with a connection probability `conn_prob`.
+
+  Parameters
+  ----------
+  conn_prob: float
+    The connection probability.
+  shape: tuple of int
+    The matrix shape.
+  seed: int
+    The random number generation seed.
+  transpose: bool
+    Transpose the random matrix or not.
+  outdim_parallel: bool
+    Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` is
+    the same as the just-in-time generated :math:`M` when ``transpose=True``.
+
+  Returns
+  -------
+  out: Array, ndarray
+    The connection matrix :math:`M`.
+  """
+  if ti is None:
+    raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
+
+  conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
+  conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
+  if seed is None:
+    with jax.ensure_compile_time_eval():
+      seed = np.random.randint(0, int(1e8), 1)
+  seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
+  return raw_get_connect_matrix(conn_len, seed, shape=shape,
+                                transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
+
+
 def raw_mv_prob_homo(
     vector: jax.Array,
     weight: jax.Array,  # vector with size 1
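Note: `conn_len` reuses the reciprocal-probability trick from `mv_prob_homo` and friends. With `clen = 2 * ceil(1 / p) - 1`, the kernels below jump ahead by a uniform draw from [1, clen] between connections, whose mean is (1 + clen) / 2 = ceil(1 / p), i.e. about one connection per 1/p entries, giving density ≈ p. A quick sketch of the functional API added here (values are illustrative):

    import brainpy.math as bm
    import jax.numpy as jnp

    # The seed fully determines the random streams, so the same arguments
    # always reproduce the same mask.
    m1 = bm.jitconn.get_connect_matrix(0.1, seed=1234, shape=(100, 200))
    m2 = bm.jitconn.get_connect_matrix(0.1, seed=1234, shape=(100, 200))
    assert m1.shape == (100, 200) and m1.dtype == jnp.bool_
    assert bool(jnp.all(m1 == m2))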
@@ -342,6 +386,28 @@ def raw_mv_prob_normal(
                       outdim_parallel=outdim_parallel)
 
 
+def raw_get_connect_matrix(
+    conn_len: jax.Array,
+    seed: jax.Array,
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+    outdim_parallel: bool = True,
+) -> jax.Array:
+  out_shape = shape if not transpose else (shape[1], shape[0])
+  if outdim_parallel:
+    prim = _get_connect_matrix_outdim_parallel_p
+  else:
+    prim = _get_connect_matrix_p
+
+  return prim(conn_len,
+              seed,
+              outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=jnp.int32)],
+              shape=shape,
+              transpose=transpose,
+              outdim_parallel=outdim_parallel)
+
+
 def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights):
   if vector.ndim != 1:
     raise ValueError('vector should be a 1D vector.')
@@ -918,3 +984,48 @@ def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel):
     cpu_kernel=_mv_prob_normal_cpu,
     gpu_kernel=_mv_prob_normal_gpu
   )
+
+
+  @ti.kernel
+  def _get_connect_matrix(
+      clen: ti.types.ndarray(),
+      seed: ti.types.ndarray(),
+      out: ti.types.ndarray(),
+  ):
+    num_row = out.shape[0]
+    num_col = out.shape[1]
+    clen0 = clen[0]
+    seed0 = seed[0]
+
+    for i_col in range(num_col):
+      key = lfsr88_key(seed0 + i_col)
+      key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+      while i_row < num_row:
+        out[i_row, i_col] = 1
+        key, inc = lfsr88_random_integers(key, 1, clen0)
+        i_row += inc
+
+
+  @ti.kernel
+  def _get_connect_matrix_outdim_parallel(
+      clen: ti.types.ndarray(),
+      seed: ti.types.ndarray(),
+      out: ti.types.ndarray(),
+  ):
+    num_row = out.shape[0]
+    num_col = out.shape[1]
+    clen0 = clen[0]
+    seed0 = seed[0]
+
+    for i_row in range(num_row):
+      key = lfsr88_key(seed0 + i_row)
+      key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+      while i_col < num_col:
+        out[i_row, i_col] = 1
+        key, inc = lfsr88_random_integers(key, 1, clen0)
+        i_col += inc
+
+
+  _get_connect_matrix_p = XLACustomOp(cpu_kernel=_get_connect_matrix, gpu_kernel=_get_connect_matrix)
+  _get_connect_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_connect_matrix_outdim_parallel,
+                                                      gpu_kernel=_get_connect_matrix_outdim_parallel)
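Note: the two kernels differ only in which axis owns a private random stream — one stream per column in `_get_connect_matrix` (used when `outdim_parallel=False`), one per row in `_get_connect_matrix_outdim_parallel`. The sampling itself is a geometric skip-ahead: rather than a Bernoulli draw per entry, each stream jumps straight from one connected index to the next. A rough NumPy model of the per-column kernel — illustrative only, since the real kernels use the LFSR88 generator and so produce different bit patterns:

    import numpy as np

    def sketch_connect_matrix(prob, seed, num_row, num_col):
      clen = int(np.ceil(1 / prob)) * 2 - 1        # jumps in [1, clen], mean ~ 1/prob
      out = np.zeros((num_row, num_col), dtype=bool)
      for i_col in range(num_col):
        rng = np.random.default_rng(seed + i_col)  # stand-in for lfsr88_key(seed0 + i_col)
        i_row = int(rng.integers(0, clen))         # first hit, uniform in [0, clen - 1]
        while i_row < num_row:
          out[i_row, i_col] = True
          i_row += int(rng.integers(1, clen + 1))  # next hit after a jump in [1, clen]
      return out

    conn = sketch_connect_matrix(0.1, seed=1234, num_row=100, num_col=200)
    print(conn.mean())  # fluctuates around 0.1

Since only the expected gap equals 1/prob, the realized connection count varies around `prob * size`, which is why the test below prints the sum instead of asserting an exact value.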
diff --git a/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py
new file mode 100644
index 00000000..4b948fc2
--- /dev/null
+++ b/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+import jax.numpy as jnp
+import pytest
+from absl.testing import parameterized
+
+import brainpy.math as bm
+from brainpy._src.dependency_check import import_taichi
+
+if import_taichi(error_if_not_found=False) is None:
+  pytest.skip('no taichi', allow_module_level=True)
+
+import platform
+
+force_test = False  # turn on to force test on windows locally
+if platform.system() == 'Windows' and not force_test:
+  pytest.skip('skip windows', allow_module_level=True)
+
+shapes = [(100, 200), (1000, 10)]
+
+
+# SEED = 1234
+
+class TestGetConnectMatrix(parameterized.TestCase):
+  def __init__(self, *args, platform='cpu', **kwargs):
+    super(TestGetConnectMatrix, self).__init__(*args, **kwargs)
+    bm.set_platform(platform)
+    print()
+
+  @parameterized.product(
+    transpose=[True, False],
+    outdim_parallel=[True, False],
+    shape=shapes,
+    prob=[0.1],
+  )
+  def test_get_connect_matrix(self, transpose, outdim_parallel, shape, prob):
+    print(
+      f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}')
+    conn = bm.jitconn.get_connect_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+    shape = (shape[1], shape[0]) if transpose else shape
+    assert conn.shape == shape
+    assert conn.dtype == jnp.bool_
+    # sum all true values
+    # assert jnp.sum(conn) == jnp.round(prob * shape[0] * shape[1])
+    print(
+      f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')
+    # print(f'conn: {conn}')
diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py
index 5af5a7e3..20a48778 100644
--- a/brainpy/_src/math/op_register/base.py
+++ b/brainpy/_src/math/op_register/base.py
@@ -13,14 +13,16 @@
   from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
   from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                  register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
-  from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
-                           register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
+  from .cupy_based import (
+    register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
+    register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
 else:
   from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
   from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                  register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
-  from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
-                           register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
+  from .cupy_based import (
+    register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
+    register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
 
 from .utils import register_general_batching
 from brainpy._src.math.op_register.ad_support import defjvp
@@ -116,7 +118,8 @@ def __init__(
         register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
         gpu_checked = True
     if not gpu_checked:
-      raise ValueError(f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}')
+      raise ValueError(
+        f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}')
 
     # batching rule
     if batching_translation is None:
diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py
index a87d27d5..441b19f2 100644
--- a/brainpy/math/jitconn.py
+++ b/brainpy/math/jitconn.py
@@ -6,5 +6,7 @@
   mv_prob_homo as mv_prob_homo,
   mv_prob_uniform as mv_prob_uniform,
   mv_prob_normal as mv_prob_normal,
+
+  get_connect_matrix as get_connect_matrix,
 )

From 11a6624d2c7b68cce6a4ece389190649ff8fa025 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Tue, 28 May 2024 11:05:28 +0800
Subject: [PATCH 2/3] Update

---
 brainpy/_src/dnn/linear.py                    | 56 +++++--------------
 brainpy/_src/math/jitconn/matvec.py           |  4 +-
 ...nect_matrix.py => test_get_conn_matrix.py} |  4 +-
 brainpy/math/jitconn.py                       |  2 +-
 4 files changed, 19 insertions(+), 47 deletions(-)
 rename brainpy/_src/math/jitconn/tests/{test_get_connect_matrix.py => test_get_conn_matrix.py} (87%)

diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index 2570835f..a1d31e08 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -980,7 +980,15 @@ def __init__(
     self.sharding = sharding
 
 
-class JitFPHomoLinear(Layer):
+class JitFPLinear(Layer):
+  def get_connect_matrix(self):
+    return bm.jitconn.get_conn_matrix(self.prob, self.seed,
+                                      shape=(self.num_out, self.num_in),
+                                      transpose=self.transpose,
+                                      outdim_parallel=not self.atomic)
+
+
+class JitFPHomoLinear(JitFPLinear):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:
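Note: the second patch deduplicates the six identical methods into a `JitFPLinear` base class. The base method relies on every subclass defining `self.prob`, `self.seed`, `self.transpose` and `self.atomic` — the base class itself never sets them. A stripped-down model of the pattern (hypothetical names, purely illustrative):

    class Base:
      def get_connect_matrix(self):
        # Assumes subclasses set prob / seed / transpose / atomic in __init__.
        return (self.prob, self.seed, self.transpose, not self.atomic)

    class Concrete(Base):
      def __init__(self):
        self.prob, self.seed = 0.1, 42
        self.transpose, self.atomic = False, False

Note also that at this point the method is still spelled `get_connect_matrix` while the function it calls is now `get_conn_matrix`; the third patch aligns the two.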
@@ -1058,14 +1066,8 @@ def _batch_mv(self, x):
                               transpose=self.transpose,
                               outdim_parallel=not self.atomic)
 
-  def get_connect_matrix(self):
-    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
-                                         shape=(self.num_out, self.num_in),
-                                         transpose=self.transpose,
-                                         outdim_parallel=not self.atomic)
-
 
-class JitFPUniformLinear(Layer):
+class JitFPUniformLinear(JitFPLinear):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:
@@ -1144,14 +1146,8 @@ def _batch_mv(self, x):
                               transpose=self.transpose,
                               outdim_parallel=not self.atomic)
 
-  def get_connect_matrix(self):
-    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
-                                         shape=(self.num_out, self.num_in),
-                                         transpose=self.transpose,
-                                         outdim_parallel=not self.atomic)
-
 
-class JitFPNormalLinear(Layer):
+class JitFPNormalLinear(JitFPLinear):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:
@@ -1230,14 +1226,8 @@ def _batch_mv(self, x):
                               transpose=self.transpose,
                               outdim_parallel=not self.atomic)
 
-  def get_connect_matrix(self):
-    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
-                                         shape=(self.num_out, self.num_in),
-                                         transpose=self.transpose,
-                                         outdim_parallel=not self.atomic)
-
 
-class EventJitFPHomoLinear(Layer):
+class EventJitFPHomoLinear(JitFPLinear):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:
@@ -1315,14 +1305,8 @@ def _batch_mv(self, x):
                               transpose=self.transpose,
                               outdim_parallel=not self.atomic)
 
-  def get_connect_matrix(self):
-    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
-                                         shape=(self.num_out, self.num_in),
-                                         transpose=self.transpose,
-                                         outdim_parallel=not self.atomic)
-
 
-class EventJitFPUniformLinear(Layer):
+class EventJitFPUniformLinear(JitFPLinear):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:
@@ -1401,14 +1385,8 @@ def _batch_mv(self, x):
                               transpose=self.transpose,
                               outdim_parallel=not self.atomic)
 
-  def get_connect_matrix(self):
-    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
-                                         shape=(self.num_out, self.num_in),
-                                         transpose=self.transpose,
-                                         outdim_parallel=not self.atomic)
-
 
-class EventJitFPNormalLinear(Layer):
+class EventJitFPNormalLinear(JitFPLinear):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:
@@ -1486,9 +1464,3 @@ def _batch_mv(self, x):
                                     shape=(self.num_out, self.num_in),
                                     transpose=self.transpose,
                                     outdim_parallel=not self.atomic)
-
-  def get_connect_matrix(self):
-    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
-                                         shape=(self.num_out, self.num_in),
-                                         transpose=self.transpose,
-                                         outdim_parallel=not self.atomic)
diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py
index ad168133..8a7ba398 100644
--- a/brainpy/_src/math/jitconn/matvec.py
+++ b/brainpy/_src/math/jitconn/matvec.py
@@ -20,7 +20,7 @@
   'mv_prob_homo',
   'mv_prob_uniform',
   'mv_prob_normal',
-  'get_connect_matrix',
+  'get_conn_matrix',
 ]
 
 
@@ -258,7 +258,7 @@ def mv_prob_normal(
                             transpose=transpose, outdim_parallel=outdim_parallel)[0]
 
 
-def get_connect_matrix(
+def get_conn_matrix(
     conn_prob: float,
     seed: Optional[int] = None,
     *,
diff --git a/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py
similarity index 87%
rename from brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py
rename to brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py
index 4b948fc2..a58be6e8 100644
--- a/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py
+++ b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py
@@ -32,10 +32,10 @@ def __init__(self, *args, platform='cpu', **kwargs):
     shape=shapes,
     prob=[0.1],
   )
-  def test_get_connect_matrix(self, transpose, outdim_parallel, shape, prob):
+  def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob):
     print(
       f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}')
-    conn = bm.jitconn.get_connect_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+    conn = bm.jitconn.get_conn_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
     shape = (shape[1], shape[0]) if transpose else shape
     assert conn.shape == shape
     assert conn.dtype == jnp.bool_
diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py
index 441b19f2..e1c4eafb 100644
--- a/brainpy/math/jitconn.py
+++ b/brainpy/math/jitconn.py
@@ -7,6 +7,6 @@
   mv_prob_uniform as mv_prob_uniform,
   mv_prob_normal as mv_prob_normal,
 
-  get_connect_matrix as get_connect_matrix,
+  get_conn_matrix as get_conn_matrix,
 )

From 3f46507a3f6bfd1ca7fb23418210ea1fff8d4b76 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Tue, 28 May 2024 11:23:43 +0800
Subject: [PATCH 3/3] Update linear.py

---
 brainpy/_src/dnn/linear.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index a1d31e08..a750142a 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -981,7 +981,7 @@ def __init__(
 
 
 class JitFPLinear(Layer):
-  def get_connect_matrix(self):
+  def get_conn_matrix(self):
     return bm.jitconn.get_conn_matrix(self.prob, self.seed,
                                       shape=(self.num_out, self.num_in),
                                       transpose=self.transpose,
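Note: with the full series applied, the public spelling is `brainpy.math.jitconn.get_conn_matrix`, and every JIT layer inherits a matching `get_conn_matrix` method from `JitFPLinear`. A closing sketch — the constructor signature is assumed as before, and the density check is approximate by design:

    import brainpy as bp
    import jax.numpy as jnp

    # Assumed signature: EventJitFPHomoLinear(num_in, num_out, prob, weight, seed=...)
    layer = bp.dnn.EventJitFPHomoLinear(100, 200, prob=0.1, weight=1.5, seed=2024)
    conn = layer.get_conn_matrix()    # dense bool mask of the JIT connectivity
    print(jnp.sum(conn) / conn.size)  # fluctuates around prob = 0.1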