diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 2570835f..a1d31e08 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -980,7 +980,15 @@ def __init__( self.sharding = sharding -class JitFPHomoLinear(Layer): +class JitFPLinear(Layer): + def get_connect_matrix(self): + return bm.jitconn.get_conn_matrix(self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class JitFPHomoLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1058,14 +1066,8 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -class JitFPUniformLinear(Layer): +class JitFPUniformLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1144,14 +1146,8 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -class JitFPNormalLinear(Layer): +class JitFPNormalLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1230,14 +1226,8 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -class EventJitFPHomoLinear(Layer): +class EventJitFPHomoLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1315,14 +1305,8 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -class EventJitFPUniformLinear(Layer): +class EventJitFPUniformLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1401,14 +1385,8 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -class EventJitFPNormalLinear(Layer): +class EventJitFPNormalLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1486,9 +1464,3 @@ def _batch_mv(self, x): shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) - - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index ad168133..8a7ba398 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -20,7 +20,7 @@ 'mv_prob_homo', 'mv_prob_uniform', 'mv_prob_normal', - 'get_connect_matrix', + 'get_conn_matrix', ] @@ -258,7 +258,7 @@ def mv_prob_normal( transpose=transpose, outdim_parallel=outdim_parallel)[0] -def get_connect_matrix( +def get_conn_matrix( conn_prob: float, seed: Optional[int] = None, *, diff --git a/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py similarity index 87% rename from brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py rename to brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py index 4b948fc2..a58be6e8 100644 --- a/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py +++ b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py @@ -32,10 +32,10 @@ def __init__(self, *args, platform='cpu', **kwargs): shape=shapes, prob=[0.1], ) - def test_get_connect_matrix(self, transpose, outdim_parallel, shape, prob): + def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob): print( f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}') - conn = bm.jitconn.get_connect_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + conn = bm.jitconn.get_conn_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) shape = (shape[1], shape[0]) if transpose else shape assert conn.shape == shape assert conn.dtype == jnp.bool_ diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index 441b19f2..e1c4eafb 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -7,6 +7,6 @@ mv_prob_uniform as mv_prob_uniform, mv_prob_normal as mv_prob_normal, - get_connect_matrix as get_connect_matrix, + get_conn_matrix as get_conn_matrix, )