Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed May 28, 2024
1 parent dd084a7 commit 11a6624
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 47 deletions.
56 changes: 14 additions & 42 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,15 @@ def __init__(
self.sharding = sharding


class JitFPHomoLinear(Layer):
class JitFPLinear(Layer):
def get_connect_matrix(self):
return bm.jitconn.get_conn_matrix(self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)


class JitFPHomoLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1058,14 +1066,8 @@ def _batch_mv(self, x):
transpose=self.transpose,
outdim_parallel=not self.atomic)

def get_connect_matrix(self):
return bm.jitconn.get_connect_matrix(self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)


class JitFPUniformLinear(Layer):
class JitFPUniformLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1144,14 +1146,8 @@ def _batch_mv(self, x):
transpose=self.transpose,
outdim_parallel=not self.atomic)

def get_connect_matrix(self):
return bm.jitconn.get_connect_matrix(self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)


class JitFPNormalLinear(Layer):
class JitFPNormalLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1230,14 +1226,8 @@ def _batch_mv(self, x):
transpose=self.transpose,
outdim_parallel=not self.atomic)

def get_connect_matrix(self):
return bm.jitconn.get_connect_matrix(self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)


class EventJitFPHomoLinear(Layer):
class EventJitFPHomoLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1315,14 +1305,8 @@ def _batch_mv(self, x):
transpose=self.transpose,
outdim_parallel=not self.atomic)

def get_connect_matrix(self):
return bm.jitconn.get_connect_matrix(self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)


class EventJitFPUniformLinear(Layer):
class EventJitFPUniformLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1401,14 +1385,8 @@ def _batch_mv(self, x):
transpose=self.transpose,
outdim_parallel=not self.atomic)

def get_connect_matrix(self):
return bm.jitconn.get_connect_matrix(self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)


class EventJitFPNormalLinear(Layer):
class EventJitFPNormalLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1486,9 +1464,3 @@ def _batch_mv(self, x):
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)

def get_connect_matrix(self):
return bm.jitconn.get_connect_matrix(self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
4 changes: 2 additions & 2 deletions brainpy/_src/math/jitconn/matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
'mv_prob_homo',
'mv_prob_uniform',
'mv_prob_normal',
'get_connect_matrix',
'get_conn_matrix',
]


Expand Down Expand Up @@ -258,7 +258,7 @@ def mv_prob_normal(
transpose=transpose, outdim_parallel=outdim_parallel)[0]


def get_connect_matrix(
def get_conn_matrix(
conn_prob: float,
seed: Optional[int] = None,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def __init__(self, *args, platform='cpu', **kwargs):
shape=shapes,
prob=[0.1],
)
def test_get_connect_matrix(self, transpose, outdim_parallel, shape, prob):
def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob):
print(
f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}')
conn = bm.jitconn.get_connect_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
conn = bm.jitconn.get_conn_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
shape = (shape[1], shape[0]) if transpose else shape
assert conn.shape == shape
assert conn.dtype == jnp.bool_
Expand Down
2 changes: 1 addition & 1 deletion brainpy/math/jitconn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
mv_prob_uniform as mv_prob_uniform,
mv_prob_normal as mv_prob_normal,

get_connect_matrix as get_connect_matrix,
get_conn_matrix as get_conn_matrix,
)

0 comments on commit 11a6624

Please sign in to comment.