Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 1, 2024
1 parent 452b2a3 commit 8defda9
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 14 deletions.
9 changes: 6 additions & 3 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,22 +979,25 @@ def __init__(
self.conn = conn
self.sharding = sharding

class JitLinear(Layer):
def get_conn_matrix(self):
pass

class JitFPHomoLayer(Layer):
class JitFPHomoLayer(JitLinear):
def get_conn_matrix(self):
return bm.jitconn.get_conn_matrix(self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)

class JitFPUniformLayer(Layer):
class JitFPUniformLayer(JitLinear):
def get_conn_matrix(self):
return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)

class JitFPNormalLayer(Layer):
class JitFPNormalLayer(JitLinear):
def get_conn_matrix(self):
return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed,
shape=(self.num_out, self.num_in),
Expand Down
20 changes: 13 additions & 7 deletions brainpy/_src/math/jitconn/matvec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-


import numbers
from typing import Tuple, Optional, Union

import jax
Expand All @@ -9,6 +8,7 @@
from jax.interpreters import ad

from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.defaults import float_
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array, _get_dtype
from brainpy._src.math.op_register import XLACustomOp
Expand All @@ -20,7 +20,7 @@
'mv_prob_homo',
'mv_prob_uniform',
'mv_prob_normal',
'get_conn_matrix',
'get_homo_weight_matrix',
'get_uniform_weight_matrix',
'get_normal_weight_matrix'
]
Expand Down Expand Up @@ -260,7 +260,8 @@ def mv_prob_normal(
transpose=transpose, outdim_parallel=outdim_parallel)[0]


def get_conn_matrix(
def get_homo_weight_matrix(
weight: float,
conn_prob: float,
seed: Optional[int] = None,
*,
Expand Down Expand Up @@ -290,6 +291,10 @@ def get_conn_matrix(
out: Array, ndarray
The connection matrix :math:`M`.
"""
if isinstance(weight, numbers.Number):
weight = jnp.atleast_1d(jnp.asarray(weight, dtype=float_))
else:
raise ValueError(f'weight must be a number type, but get {type(weight)}')
if ti is None:
raise PackageMissingError.by_purpose('taichi', purpose='customized operators')

Expand All @@ -299,8 +304,9 @@ def get_conn_matrix(
with jax.ensure_compile_time_eval():
seed = np.random.randint(0, int(1e8), 1)
seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
r = raw_get_connect_matrix(conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
r = raw_get_homo_weight_matrix(conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
r *= weight
if transpose:
return r.transpose()
else:
Expand Down Expand Up @@ -501,7 +507,7 @@ def raw_mv_prob_normal(
outdim_parallel=outdim_parallel)


def raw_get_connect_matrix(
def raw_get_homo_weight_matrix(
conn_len: jax.Array,
seed: jax.Array,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,20 @@ def __init__(self, *args, platform='cpu', **kwargs):
prob=[0.1],
)
def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob):
homo_data = 1.
print(
f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}')
conn = bm.jitconn.get_conn_matrix(prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
conn = bm.jitconn.get_homo_weight_matrix(homo_data, prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
shape = (shape[1], shape[0]) if transpose else shape
print(conn.shape)
assert conn.shape == shape
assert conn.dtype == jnp.bool_
# assert conn.dtype == jnp.float_
# sum all true values
print(
f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')

# compare with jitconn op
homo_data = 1.

rng = bm.random.RandomState()
vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))

Expand Down
2 changes: 1 addition & 1 deletion brainpy/math/jitconn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
mv_prob_uniform as mv_prob_uniform,
mv_prob_normal as mv_prob_normal,

get_conn_matrix as get_conn_matrix,
get_homo_weight_matrix as get_homo_weight_matrix,
get_uniform_weight_matrix as get_uniform_weight_matrix,
get_normal_weight_matrix as get_normal_weight_matrix,
)
Expand Down

0 comments on commit 8defda9

Please sign in to comment.