From 0e6dbc2210064a2651ddc67bd0c9e23aa498da5f Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 29 May 2024 17:05:55 +0800 Subject: [PATCH] [math] Add get JIT weight matrix methods(Uniform & Normal) for `brainpy.dnn.linear` --- brainpy/_src/dnn/linear.py | 27 +- brainpy/_src/math/jitconn/matvec.py | 284 +++++++++++++++++- .../jitconn/tests/test_get_conn_matrix.py | 131 +++++++- brainpy/math/jitconn.py | 2 + 4 files changed, 426 insertions(+), 18 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index a750142a..03398766 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -980,15 +980,28 @@ def __init__( self.sharding = sharding -class JitFPLinear(Layer): +class JitFPHomoLinear(Layer): def get_conn_matrix(self): return bm.jitconn.get_conn_matrix(self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) +class JitFPUniformLinear(Layer): + def get_conn_matrix(self): + return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + +class JitFPNormalLinear(Layer): + def get_conn_matrix(self): + return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) -class JitFPHomoLinear(JitFPLinear): +class JitFPHomoLinear(JitFPHomoLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1067,7 +1080,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class JitFPUniformLinear(JitFPLinear): +class JitFPUniformLinear(JitFPUniformLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1147,7 +1160,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class JitFPNormalLinear(JitFPLinear): +class JitFPNormalLinear(JitFPNormalLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1227,7 +1240,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPHomoLinear(JitFPLinear): +class EventJitFPHomoLinear(JitFPHomoLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1306,7 +1319,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPUniformLinear(JitFPLinear): +class EventJitFPUniformLinear(JitFPUniformLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1386,7 +1399,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPNormalLinear(JitFPLinear): +class EventJitFPNormalLinear(JitFPNormalLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. 
   It performs the computation of:
diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py
index 8a7ba398..e2ae2322 100644
--- a/brainpy/_src/math/jitconn/matvec.py
+++ b/brainpy/_src/math/jitconn/matvec.py
@@ -21,6 +21,8 @@
   'mv_prob_uniform',
   'mv_prob_normal',
   'get_conn_matrix',
+  'get_uniform_weight_matrix',
+  'get_normal_weight_matrix'
 ]


@@ -297,8 +299,121 @@ def get_conn_matrix(
     with jax.ensure_compile_time_eval():
       seed = np.random.randint(0, int(1e8), 1)
   seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
-  return raw_get_connect_matrix(conn_len, seed, shape=shape,
-                                transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
+  r = raw_get_connect_matrix(conn_len, seed, shape=shape,
+                             transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
+  if transpose:
+    return r.transpose()
+  else:
+    return r
+
+
+def get_uniform_weight_matrix(
+    w_low: float,
+    w_high: float,
+    conn_prob: float,
+    seed: Optional[int] = None,
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+    outdim_parallel: bool = True,
+) -> jax.Array:
+  r"""Get the weight matrix :math:`M` with a uniform distribution for its values.
+
+  Parameters
+  ----------
+  w_low: float
+    Lower boundary of the output interval.
+  w_high: float
+    Upper boundary of the output interval.
+  conn_prob: float
+    The connection probability.
+  shape: tuple of int
+    The matrix shape.
+  seed: int
+    The random number generation seed.
+  transpose: bool
+    Transpose the random matrix or not.
+  outdim_parallel: bool
+    Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` is the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+  Returns
+  -------
+  out: Array, ndarray
+    The weight matrix :math:`M`.
+  """
+  if ti is None:
+    raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
+
+  w_low = jnp.atleast_1d(as_jax(w_low))
+  w_high = jnp.atleast_1d(as_jax(w_high))
+  conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
+  conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
+  if seed is None:
+    with jax.ensure_compile_time_eval():
+      seed = np.random.randint(0, int(1e8), 1)
+  seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
+  r = raw_get_uniform_weight_matrix(w_low, w_high, conn_len, seed, shape=shape,
+                                    transpose=transpose, outdim_parallel=outdim_parallel)[0]
+  if transpose:
+    return r.transpose()
+  else:
+    return r
+
+
+def get_normal_weight_matrix(
+    w_mu: float,
+    w_sigma: float,
+    conn_prob: float,
+    seed: Optional[int] = None,
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+    outdim_parallel: bool = True,
+) -> jax.Array:
+  r"""Get the weight matrix :math:`M` with a normal distribution for its values.
+
+  Parameters
+  ----------
+  w_mu: float
+    Mean (centre) of the distribution.
+  w_sigma: float
+    Standard deviation (spread or “width”) of the distribution. Must be non-negative.
+  conn_prob: float
+    The connection probability.
+  shape: tuple of int
+    The matrix shape.
+  seed: int
+    The random number generation seed.
+  transpose: bool
+    Transpose the random matrix or not.
+  outdim_parallel: bool
+    Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` is the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+  Returns
+  -------
+  out: Array, ndarray
+    The weight matrix :math:`M`.
+ """ + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + r = raw_get_normal_weight_matrix(w_mu, w_sigma, conn_len, seed, + shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + if transpose: + return r.transpose() + else: + return r def raw_mv_prob_homo( @@ -394,7 +509,6 @@ def raw_get_connect_matrix( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - out_shape = shape if not transpose else (shape[1], shape[0]) if outdim_parallel: prim = _get_connect_matrix_outdim_parallel_p else: @@ -402,7 +516,57 @@ def raw_get_connect_matrix( return prim(conn_len, seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=jnp.int32)], + outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.int32)], + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def raw_get_uniform_weight_matrix( + w_low: jax.Array, + w_high: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + if outdim_parallel: + prim = _get_uniform_weight_matrix_outdim_parallel_p + else: + prim = _get_uniform_weight_matrix_p + + return prim(w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)], + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def raw_get_normal_weight_matrix( + w_mu: jax.Array, + w_sigma: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + if outdim_parallel: + prim = _get_normal_weight_matrix_outdim_parallel_p + else: + prim = _get_normal_weight_matrix_p + + return prim(w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)], shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) @@ -1029,3 +1193,115 @@ def _get_connect_matrix_outdim_parallel( _get_connect_matrix_p = XLACustomOp(cpu_kernel=_get_connect_matrix, gpu_kernel=_get_connect_matrix) _get_connect_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_connect_matrix_outdim_parallel, gpu_kernel=_get_connect_matrix_outdim_parallel) + + + @ti.kernel + def _get_uniform_weight_matrix( + w_low: ti.types.ndarray(), + w_high: ti.types.ndarray(), + clen: ti.types.ndarray(), + seed: ti.types.ndarray(), + out: ti.types.ndarray(), + ): + num_row = out.shape[0] + num_col = out.shape[1] + w_low0 = w_low[0] + w_high0 = w_high[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_uniform(key, w_low0, w_high0) + out[i_row, i_col] += raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _get_uniform_weight_matrix_outdim_parallel( + w_low: ti.types.ndarray(), + w_high: ti.types.ndarray(), + clen: ti.types.ndarray(), + seed: ti.types.ndarray(), + out: ti.types.ndarray(), + ): + num_row = out.shape[0] + num_col = out.shape[1] + w_low0 = w_low[0] + w_high0 = w_high[0] + clen0 = clen[0] + seed0 = 
seed[0] + + for i_row in range(num_row): + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_uniform(key, w_low0, w_high0) + out[i_row, i_col] += raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + + + _get_uniform_weight_matrix_p = XLACustomOp(cpu_kernel=_get_uniform_weight_matrix, + gpu_kernel=_get_uniform_weight_matrix) + _get_uniform_weight_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_uniform_weight_matrix_outdim_parallel, + gpu_kernel=_get_uniform_weight_matrix_outdim_parallel) + + + @ti.kernel + def _get_normal_weight_matrix( + w_mu: ti.types.ndarray(), + w_sigma: ti.types.ndarray(), + clen: ti.types.ndarray(), + seed: ti.types.ndarray(), + out: ti.types.ndarray(), + ): + num_row = out.shape[0] + num_col = out.shape[1] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row, i_col] += raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _get_normal_weight_matrix_outdim_parallel( + w_mu: ti.types.ndarray(), + w_sigma: ti.types.ndarray(), + clen: ti.types.ndarray(), + seed: ti.types.ndarray(), + out: ti.types.ndarray(), + ): + num_row = out.shape[0] + num_col = out.shape[1] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row, i_col] += raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + + + _get_normal_weight_matrix_p = XLACustomOp(cpu_kernel=_get_normal_weight_matrix, + gpu_kernel=_get_normal_weight_matrix) + _get_normal_weight_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_normal_weight_matrix_outdim_parallel, + gpu_kernel=_get_normal_weight_matrix_outdim_parallel) diff --git a/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py index a58be6e8..e671199d 100644 --- a/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py +++ b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py @@ -12,13 +12,13 @@ import platform force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) +# if platform.system() == 'Windows' and not force_test: +# pytest.skip('skip windows', allow_module_level=True) -shapes = [(100, 200), (1000, 10)] +shapes = [(10, 20), (1000, 10)] +SEED = 1234 -# SEED = 1234 class TestGetConnectMatrix(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -35,12 +35,129 @@ def __init__(self, *args, platform='cpu', **kwargs): def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob): print( f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}') - conn = bm.jitconn.get_conn_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + conn = bm.jitconn.get_conn_matrix(prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) shape = (shape[1], shape[0]) if transpose else shape + print(conn.shape) assert conn.shape == shape 
     assert conn.dtype == jnp.bool_
     # sum all true values
-    # assert jnp.sum(conn) == jnp.round(prob * shape[0] * shape[1])
     print(
       f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')
-    # print(f'conn: {conn}')
+
+    # compare with jitconn op
+    homo_data = 1.
+    rng = bm.random.RandomState()
+    vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+    r1 = bm.jitconn.mv_prob_homo(vector,
+                                 homo_data,
+                                 conn_prob=prob,
+                                 shape=shape,
+                                 seed=SEED,
+                                 outdim_parallel=outdim_parallel,
+                                 transpose=transpose)
+
+    r2 = vector @ conn if transpose else conn @ vector
+    self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+    bm.clear_buffer_memory()
+
+
+class TestGetUniformWeightMatrix(parameterized.TestCase):
+  def __init__(self, *args, platform='cpu', **kwargs):
+    super(TestGetUniformWeightMatrix, self).__init__(*args, **kwargs)
+    bm.set_platform(platform)
+    print()
+
+  @parameterized.product(
+    transpose=[True, False],
+    outdim_parallel=[True, False],
+    shape=shapes,
+    prob=[0.1],
+    w_low=[0.1],
+    w_high=[0.9],
+  )
+  def test_get_uniform_weight_matrix(self, transpose, outdim_parallel, shape, prob, w_low, w_high):
+    print(
+      f'test_get_uniform_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}, w_low={w_low}, w_high={w_high}')
+    weight = bm.jitconn.get_uniform_weight_matrix(w_low, w_high, prob, shape=shape, transpose=transpose,
+                                                  outdim_parallel=outdim_parallel)
+    shape = (shape[1], shape[0]) if transpose else shape
+    assert weight.shape == shape
+    assert weight.dtype == jnp.float32
+
+    weight_true = weight > 0.
+
+    print(
+      f'jnp.sum(weight_true): {jnp.sum(weight_true)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')
+
+    # CANNOT BE TESTED IN THIS WAY, BECAUSE THE UNIFORM JITCONN OP HAS BEEN OPTIMIZED
+    # compare with jitconn op
+
+    # rng = bm.random.RandomState()
+    # events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+    #
+    # r1 = bm.jitconn.mv_prob_uniform(events,
+    #                                 w_low=w_low,
+    #                                 w_high=w_high,
+    #                                 conn_prob=prob,
+    #                                 shape=shape,
+    #                                 seed=SEED,
+    #                                 outdim_parallel=outdim_parallel,
+    #                                 transpose=transpose)
+    #
+    # r2 = events @ weight if transpose else weight @ events
+    # print(f'r1: {r1}\n r2: {r2}')
+    # self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+    bm.clear_buffer_memory()
+
+
+class TestGetNormalWeightMatrix(parameterized.TestCase):
+  def __init__(self, *args, platform='cpu', **kwargs):
+    super(TestGetNormalWeightMatrix, self).__init__(*args, **kwargs)
+    bm.set_platform(platform)
+    print()
+
+  @parameterized.product(
+    transpose=[True, False],
+    outdim_parallel=[True, False],
+    shape=shapes,
+    prob=[0.1],
+    w_mu=[0.0],
+    w_sigma=[1.0],
+  )
+  def test_get_normal_weight_matrix(self, transpose, outdim_parallel, shape, prob, w_mu, w_sigma):
+    print(
+      f'test_get_normal_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}, w_mu={w_mu}, w_sigma={w_sigma}')
+    weight = bm.jitconn.get_normal_weight_matrix(w_mu, w_sigma, prob, shape=shape, transpose=transpose,
+                                                 outdim_parallel=outdim_parallel)
+    shape = (shape[1], shape[0]) if transpose else shape
+    assert weight.shape == shape
+    assert weight.dtype == jnp.float32
+
+    weight_true = weight != 0.  # normal weights can be negative, so count non-zero entries
+
+    print(
+      f'jnp.sum(weight_true): {jnp.sum(weight_true)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')
+
+    # CANNOT BE TESTED IN THIS WAY, BECAUSE THE NORMAL JITCONN OP HAS BEEN OPTIMIZED
+    # compare with jitconn op
+
+    # rng = bm.random.RandomState()
+    # vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+    #
+    # r1 = bm.jitconn.mv_prob_normal(vector,
+    #                                w_mu=w_mu,
+    #                                w_sigma=w_sigma,
+    #                                conn_prob=prob,
+    #                                shape=shape,
+    #                                seed=SEED,
+    #                                outdim_parallel=outdim_parallel,
+    #                                transpose=transpose)
+    #
+    # r2 = vector @ weight if transpose else weight @ vector
+    # print(f'r1: {r1}\n r2: {r2}')
+    # self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+    bm.clear_buffer_memory()
diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py
index e1c4eafb..79a9bc87 100644
--- a/brainpy/math/jitconn.py
+++ b/brainpy/math/jitconn.py
@@ -8,5 +8,7 @@
   mv_prob_normal as mv_prob_normal,
   get_conn_matrix as get_conn_matrix,
+  get_uniform_weight_matrix as get_uniform_weight_matrix,
+  get_normal_weight_matrix as get_normal_weight_matrix,
 )
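Note (not part of the diff): a minimal usage sketch of the two functional APIs added above. It relies only on the signatures introduced in brainpy/_src/math/jitconn/matvec.py; the probability, seed, and shape values are illustrative.

import brainpy.math as bm

# Materialize the JIT-generated weight matrices for a fixed (conn_prob, seed, shape).
W_uniform = bm.jitconn.get_uniform_weight_matrix(
    0.1, 0.9, conn_prob=0.1, seed=1234, shape=(100, 200))
W_normal = bm.jitconn.get_normal_weight_matrix(
    0.0, 1.0, conn_prob=0.1, seed=1234, shape=(100, 200))

print(W_uniform.shape, W_uniform.dtype)  # expected: (100, 200) float32
print(W_normal.shape, W_normal.dtype)    # expected: (100, 200) float32

# Rough density check mirroring the prints in the tests: the number of
# non-zero entries should be close to conn_prob * 100 * 200.
print(int((W_uniform != 0.).sum()))

As the commented-out checks in the tests note, the uniform and normal jitconn matvec kernels are optimized differently, so these materialized matrices are meant for inspection and dense computation rather than bit-exact comparison with mv_prob_uniform / mv_prob_normal.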
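Likewise, a hedged sketch of the new layer-level get_conn_matrix() accessors added to brainpy/_src/dnn/linear.py. The constructor keyword names below (num_in, num_out, prob, w_mu, w_sigma, seed) are assumed from the attributes read inside get_conn_matrix() and may differ from the actual layer signatures.

import brainpy as bp

# Assumed constructor arguments; get_conn_matrix() itself only needs the
# prob, seed, w_mu, w_sigma, num_in, num_out, transpose, and atomic attributes.
layer = bp.dnn.JitFPNormalLinear(num_in=200, num_out=100,
                                 prob=0.1, w_mu=0.0, w_sigma=1.0, seed=1234)

W = layer.get_conn_matrix()  # calls bm.jitconn.get_normal_weight_matrix under the hood
print(W.shape, W.dtype)      # expected: (100, 200) float32 when transpose is False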