diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 729fdc70..d1dbd308 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -979,22 +979,25 @@ def __init__( self.conn = conn self.sharding = sharding +class JitLinear(Layer): + def get_conn_matrix(self): + pass -class JitFPHomoLayer(Layer): +class JitFPHomoLayer(JitLinear): def get_conn_matrix(self): return bm.jitconn.get_conn_matrix(self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class JitFPUniformLayer(Layer): +class JitFPUniformLayer(JitLinear): def get_conn_matrix(self): return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class JitFPNormalLayer(Layer): +class JitFPNormalLayer(JitLinear): def get_conn_matrix(self): return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index e2ae2322..08310506 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- - - +import numbers from typing import Tuple, Optional, Union import jax @@ -9,6 +8,7 @@ from jax.interpreters import ad from brainpy._src.dependency_check import import_taichi +from brainpy._src.math.defaults import float_ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype from brainpy._src.math.op_register import XLACustomOp @@ -20,7 +20,7 @@ 'mv_prob_homo', 'mv_prob_uniform', 'mv_prob_normal', - 'get_conn_matrix', + 'get_homo_weight_matrix', 'get_uniform_weight_matrix', 'get_normal_weight_matrix' ] @@ -260,7 +260,8 @@ def mv_prob_normal( transpose=transpose, outdim_parallel=outdim_parallel)[0] -def get_conn_matrix( +def get_homo_weight_matrix( + weight: float, conn_prob: float, seed: Optional[int] = None, *, @@ -290,6 +291,10 @@ def get_conn_matrix( out: Array, ndarray The connection matrix :math:`M`. """ + if isinstance(weight, numbers.Number): + weight = jnp.atleast_1d(jnp.asarray(weight, dtype=float_)) + else: + raise ValueError(f'weight must be a number type, but get {type(weight)}') if ti is None: raise PackageMissingError.by_purpose('taichi', purpose='customized operators') @@ -299,8 +304,9 @@ def get_conn_matrix( with jax.ensure_compile_time_eval(): seed = np.random.randint(0, int(1e8), 1) seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - r = raw_get_connect_matrix(conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_) + r = raw_get_homo_weight_matrix(conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_) + r *= weight if transpose: return r.transpose() else: @@ -501,7 +507,7 @@ def raw_mv_prob_normal( outdim_parallel=outdim_parallel) -def raw_get_connect_matrix( +def raw_get_homo_weight_matrix( conn_len: jax.Array, seed: jax.Array, *, diff --git a/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py similarity index 97% rename from brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py rename to brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py index bf94065e..fca88d6d 100644 --- a/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py +++ b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py @@ -33,19 +33,20 @@ def __init__(self, *args, platform='cpu', **kwargs): prob=[0.1], ) def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob): + homo_data = 1. print( f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}') - conn = bm.jitconn.get_conn_matrix(prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + conn = bm.jitconn.get_homo_weight_matrix(homo_data, prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) shape = (shape[1], shape[0]) if transpose else shape print(conn.shape) assert conn.shape == shape - assert conn.dtype == jnp.bool_ + # assert conn.dtype == jnp.float_ # sum all true values print( f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') # compare with jitconn op - homo_data = 1. + rng = bm.random.RandomState() vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index 79a9bc87..3c99b7de 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -7,7 +7,7 @@ mv_prob_uniform as mv_prob_uniform, mv_prob_normal as mv_prob_normal, - get_conn_matrix as get_conn_matrix, + get_homo_weight_matrix as get_homo_weight_matrix, get_uniform_weight_matrix as get_uniform_weight_matrix, get_normal_weight_matrix as get_normal_weight_matrix, )