From 452b2a3a9ad1bca6530f5102c717addd6d3f62d4 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 30 May 2024 12:40:45 +0800 Subject: [PATCH] Update linear.py --- brainpy/_src/dnn/linear.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 03398766..729fdc70 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -980,28 +980,28 @@ def __init__( self.sharding = sharding -class JitFPHomoLinear(Layer): +class JitFPHomoLayer(Layer): def get_conn_matrix(self): return bm.jitconn.get_conn_matrix(self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class JitFPUniformLinear(Layer): +class JitFPUniformLayer(Layer): def get_conn_matrix(self): return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class JitFPNormalLinear(Layer): +class JitFPNormalLayer(Layer): def get_conn_matrix(self): return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class JitFPHomoLinear(JitFPHomoLinear): +class JitFPHomoLinear(JitFPHomoLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1080,7 +1080,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class JitFPUniformLinear(JitFPUniformLinear): +class JitFPUniformLinear(JitFPUniformLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1160,7 +1160,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class JitFPNormalLinear(JitFPNormalLinear): +class JitFPNormalLinear(JitFPNormalLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1240,7 +1240,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPHomoLinear(JitFPHomoLinear): +class EventJitFPHomoLinear(JitFPHomoLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1319,7 +1319,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPUniformLinear(JitFPUniformLinear): +class EventJitFPUniformLinear(JitFPUniformLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1399,7 +1399,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPNormalLinear(JitFPNormalLinear): +class EventJitFPNormalLinear(JitFPNormalLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: