Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 1, 2024
1 parent 939f030 commit 3908afe
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,9 +988,9 @@ def get_conn_matrix(self):
class JitFPHomoLayer(JitLinear):
def get_conn_matrix(self):
return bm.jitconn.get_homo_weight_matrix(self.weight, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)


class JitFPUniformLayer(JitLinear):
Expand Down
12 changes: 6 additions & 6 deletions brainpy/_src/math/jitconn/matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from jax.interpreters import ad

from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.defaults import float_
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array, _get_dtype
from brainpy._src.math.op_register import XLACustomOp
from brainpy.errors import PackageMissingError
from brainpy._src.math import defauts

ti = import_taichi(error_if_not_found=False)

Expand Down Expand Up @@ -292,7 +292,7 @@ def get_homo_weight_matrix(
The connection matrix :math:`M`.
"""
if isinstance(weight, numbers.Number):
weight = jnp.atleast_1d(jnp.asarray(weight, dtype=float_))
weight = jnp.atleast_1d(jnp.asarray(weight, dtype=defauts.float_))
else:
raise ValueError(f'weight must be a number type, but get {type(weight)}')
if ti is None:
Expand Down Expand Up @@ -1221,7 +1221,7 @@ def _get_uniform_weight_matrix(
key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
while i_row < num_row:
key, raw_v = lfsr88_uniform(key, w_low0, w_high0)
out[i_row, i_col] += raw_v
out[i_row, i_col] = raw_v
key, inc = lfsr88_random_integers(key, 1, clen0)
i_row += inc

Expand All @@ -1246,7 +1246,7 @@ def _get_uniform_weight_matrix_outdim_parallel(
key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
while i_col < num_col:
key, raw_v = lfsr88_uniform(key, w_low0, w_high0)
out[i_row, i_col] += raw_v
out[i_row, i_col] = raw_v
key, inc = lfsr88_random_integers(key, 1, clen0)
i_col += inc

Expand Down Expand Up @@ -1277,7 +1277,7 @@ def _get_normal_weight_matrix(
key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
while i_row < num_row:
key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0)
out[i_row, i_col] += raw_v
out[i_row, i_col] = raw_v
key, inc = lfsr88_random_integers(key, 1, clen0)
i_row += inc

Expand All @@ -1302,7 +1302,7 @@ def _get_normal_weight_matrix_outdim_parallel(
key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
while i_col < num_col:
key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0)
out[i_row, i_col] += raw_v
out[i_row, i_col] = raw_v
key, inc = lfsr88_random_integers(key, 1, clen0)
i_col += inc

Expand Down

0 comments on commit 3908afe

Please sign in to comment.