Skip to content

Commit

Permalink
Update linear.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed May 30, 2024
1 parent 1aa42aa commit 452b2a3
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,28 +980,28 @@ def __init__(
self.sharding = sharding


class JitFPHomoLinear(Layer):
class JitFPHomoLayer(Layer):
def get_conn_matrix(self):
return bm.jitconn.get_conn_matrix(self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)

class JitFPUniformLinear(Layer):
class JitFPUniformLayer(Layer):
def get_conn_matrix(self):
return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)

class JitFPNormalLinear(Layer):
class JitFPNormalLayer(Layer):
def get_conn_matrix(self):
return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)

class JitFPHomoLinear(JitFPHomoLinear):
class JitFPHomoLinear(JitFPHomoLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1080,7 +1080,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


class JitFPUniformLinear(JitFPUniformLinear):
class JitFPUniformLinear(JitFPUniformLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1160,7 +1160,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


class JitFPNormalLinear(JitFPNormalLinear):
class JitFPNormalLinear(JitFPNormalLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1240,7 +1240,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


class EventJitFPHomoLinear(JitFPHomoLinear):
class EventJitFPHomoLinear(JitFPHomoLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1319,7 +1319,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


class EventJitFPUniformLinear(JitFPUniformLinear):
class EventJitFPUniformLinear(JitFPUniformLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down Expand Up @@ -1399,7 +1399,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


class EventJitFPNormalLinear(JitFPNormalLinear):
class EventJitFPNormalLinear(JitFPNormalLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
Expand Down

0 comments on commit 452b2a3

Please sign in to comment.