[math] Add get JIT weight matrix methods (Uniform & Normal) for `brainpy.dnn.linear`
Routhleck committed May 29, 2024
1 parent 3f46507 commit 0e6dbc2
Showing 4 changed files with 426 additions and 18 deletions.
27 changes: 20 additions & 7 deletions brainpy/_src/dnn/linear.py
@@ -980,15 +980,28 @@ def __init__(
self.sharding = sharding


-class JitFPLinear(Layer):
+class JitFPHomoLinear(Layer):
  def get_conn_matrix(self):
    return bm.jitconn.get_conn_matrix(self.prob, self.seed,
                                      shape=(self.num_out, self.num_in),
                                      transpose=self.transpose,
                                      outdim_parallel=not self.atomic)

class JitFPUniformLinear(Layer):
  def get_conn_matrix(self):
    return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed,
                                                shape=(self.num_out, self.num_in),
                                                transpose=self.transpose,
                                                outdim_parallel=not self.atomic)

class JitFPNormalLinear(Layer):
  def get_conn_matrix(self):
    return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed,
                                               shape=(self.num_out, self.num_in),
                                               transpose=self.transpose,
                                               outdim_parallel=not self.atomic)

-class JitFPHomoLinear(JitFPLinear):
+class JitFPHomoLinear(JitFPHomoLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -1067,7 +1080,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


-class JitFPUniformLinear(JitFPLinear):
+class JitFPUniformLinear(JitFPUniformLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -1147,7 +1160,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


-class JitFPNormalLinear(JitFPLinear):
+class JitFPNormalLinear(JitFPNormalLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -1227,7 +1240,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


-class EventJitFPHomoLinear(JitFPLinear):
+class EventJitFPHomoLinear(JitFPHomoLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -1306,7 +1319,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


-class EventJitFPUniformLinear(JitFPLinear):
+class EventJitFPUniformLinear(JitFPUniformLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -1386,7 +1399,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


-class EventJitFPNormalLinear(JitFPLinear):
+class EventJitFPNormalLinear(JitFPNormalLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
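The layer-level effect of this change: every jitconn layer now exposes a `get_conn_matrix()` that materializes the exact weights its kernel otherwise generates on the fly. A minimal usage sketch, my addition rather than part of the commit; it assumes the existing constructor signature `JitFPUniformLinear(num_in, num_out, prob, w_low, w_high, seed)` and a working taichi installation:

import brainpy as bp
import brainpy.math as bm

layer = bp.dnn.JitFPUniformLinear(100, 200, prob=0.1, w_low=0., w_high=1., seed=42)
x = bm.random.rand(100)
y = layer(x)                 # weights are regenerated just-in-time inside the kernel
M = layer.get_conn_matrix()  # new in this commit: the same weights as a dense array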
284 changes: 280 additions & 4 deletions brainpy/_src/math/jitconn/matvec.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,8 @@
'mv_prob_uniform',
'mv_prob_normal',
'get_conn_matrix',
'get_uniform_weight_matrix',
'get_normal_weight_matrix'
]


@@ -297,8 +299,121 @@ def get_conn_matrix(
    with jax.ensure_compile_time_eval():
      seed = np.random.randint(0, int(1e8), 1)
  seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
-  return raw_get_connect_matrix(conn_len, seed, shape=shape,
-                                transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
+  r = raw_get_connect_matrix(conn_len, seed, shape=shape,
+                             transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
+  if transpose:
+    return r.transpose()
+  else:
+    return r
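A small sketch of the new return convention, my addition; the argument order follows the call shown in linear.py above:

import brainpy.math as bm

conn = bm.jitconn.get_conn_matrix(0.1, 42, shape=(200, 100), transpose=True)
# The kernel samples in the (200, 100) orientation; the new `if transpose`
# branch flips the result back, so callers receive a (100, 200) boolean mask.
print(conn.shape, conn.dtype)  # expected: (100, 200) bool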


def get_uniform_weight_matrix(
    w_low: float,
    w_high: float,
    conn_prob: float,
    seed: Optional[int] = None,
    *,
    shape: Tuple[int, int],
    transpose: bool = False,
    outdim_parallel: bool = True,
) -> jax.Array:
  r"""Get the weight matrix :math:`M` whose values are drawn from a uniform distribution.

  Parameters
  ----------
  w_low: float
    Lower boundary of the output interval.
  w_high: float
    Upper boundary of the output interval.
  conn_prob: float
    The connection probability.
  shape: tuple of int
    The matrix shape.
  seed: int
    The random number generation seed.
  transpose: bool
    Transpose the random matrix or not.
  outdim_parallel: bool
    Perform the parallel random generations along the out dimension or not.
    It can be used to ensure that the just-in-time generated :math:`M^T` is the
    same as the just-in-time generated :math:`M` when ``transpose=True``.

  Returns
  -------
  out: Array, ndarray
    The weight matrix :math:`M`.
  """
  if ti is None:
    raise PackageMissingError.by_purpose('taichi', purpose='customized operators')

  w_low = jnp.atleast_1d(as_jax(w_low))
  w_high = jnp.atleast_1d(as_jax(w_high))
  conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
  conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
  if seed is None:
    with jax.ensure_compile_time_eval():
      seed = np.random.randint(0, int(1e8), 1)
  seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
  r = raw_get_uniform_weight_matrix(w_low, w_high, conn_len, seed, shape=shape,
                                    transpose=transpose, outdim_parallel=outdim_parallel)[0]
  if transpose:
    return r.transpose()
  else:
    return r
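The point of this getter is that the dense matrix reproduces the implicit matvec. A consistency-check sketch, my addition; it assumes `mv_prob_uniform`'s existing signature `(vector, w_low, w_high, conn_prob, seed, *, shape, ...)` and matching defaults:

import jax.numpy as jnp
import brainpy.math as bm

shape, prob, seed = (200, 100), 0.1, 2024
x = jnp.ones(100)
M = bm.jitconn.get_uniform_weight_matrix(0.0, 1.0, prob, seed, shape=shape)
y_dense = M @ x                                                  # explicit dense matvec
y_jit = bm.jitconn.mv_prob_uniform(x, 0.0, 1.0, prob, seed, shape=shape)
print(jnp.allclose(y_dense, y_jit))                              # expected: True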


def get_normal_weight_matrix(
    w_mu: float,
    w_sigma: float,
    conn_prob: float,
    seed: Optional[int] = None,
    *,
    shape: Tuple[int, int],
    transpose: bool = False,
    outdim_parallel: bool = True,
) -> jax.Array:
  r"""Get the weight matrix :math:`M` whose values are drawn from a normal distribution.

  Parameters
  ----------
  w_mu: float
    Mean (centre) of the distribution.
  w_sigma: float
    Standard deviation (spread or “width”) of the distribution. Must be non-negative.
  conn_prob: float
    The connection probability.
  shape: tuple of int
    The matrix shape.
  seed: int
    The random number generation seed.
  transpose: bool
    Transpose the random matrix or not.
  outdim_parallel: bool
    Perform the parallel random generations along the out dimension or not.
    It can be used to ensure that the just-in-time generated :math:`M^T` is the
    same as the just-in-time generated :math:`M` when ``transpose=True``.

  Returns
  -------
  out: Array, ndarray
    The weight matrix :math:`M`.
  """
  if ti is None:
    raise PackageMissingError.by_purpose('taichi', purpose='customized operators')

  w_mu = jnp.atleast_1d(as_jax(w_mu))
  w_sigma = jnp.atleast_1d(as_jax(w_sigma))
  conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
  conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
  if seed is None:
    with jax.ensure_compile_time_eval():
      seed = np.random.randint(0, int(1e8), 1)
  seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
  r = raw_get_normal_weight_matrix(w_mu, w_sigma, conn_len, seed,
                                   shape=shape,
                                   transpose=transpose, outdim_parallel=outdim_parallel)[0]
  if transpose:
    return r.transpose()
  else:
    return r
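A quick sanity-check sketch for the normal variant, my addition: generation is a pure function of the seed, the density of nonzero entries follows `conn_prob` (the mean skip length `(1 + conn_len) / 2` equals `ceil(1 / conn_prob)`), and the nonzero values carry the requested moments.

import jax.numpy as jnp
import brainpy.math as bm

M1 = bm.jitconn.get_normal_weight_matrix(0.5, 0.1, 0.2, 7, shape=(300, 400))
M2 = bm.jitconn.get_normal_weight_matrix(0.5, 0.1, 0.2, 7, shape=(300, 400))
print(jnp.array_equal(M1, M2))  # True: same seed, same matrix
nz = M1[M1 != 0.0]
print(nz.size / M1.size)        # roughly 0.2, the requested connection probability
print(nz.mean(), nz.std())      # roughly 0.5 and 0.1 on a matrix this large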


def raw_mv_prob_homo(
@@ -394,15 +509,64 @@ def raw_get_connect_matrix(
    transpose: bool = False,
    outdim_parallel: bool = True,
) -> jax.Array:
-  out_shape = shape if not transpose else (shape[1], shape[0])
  if outdim_parallel:
    prim = _get_connect_matrix_outdim_parallel_p
  else:
    prim = _get_connect_matrix_p

  return prim(conn_len,
              seed,
-              outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=jnp.int32)],
+              outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.int32)],
              shape=shape,
              transpose=transpose,
              outdim_parallel=outdim_parallel)


def raw_get_uniform_weight_matrix(
    w_low: jax.Array,
    w_high: jax.Array,
    conn_len: jax.Array,
    seed: jax.Array,
    *,
    shape: Tuple[int, int],
    transpose: bool = False,
    outdim_parallel: bool = True,
) -> jax.Array:
  if outdim_parallel:
    prim = _get_uniform_weight_matrix_outdim_parallel_p
  else:
    prim = _get_uniform_weight_matrix_p

  return prim(w_low,
              w_high,
              conn_len,
              seed,
              outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)],
              shape=shape,
              transpose=transpose,
              outdim_parallel=outdim_parallel)


def raw_get_normal_weight_matrix(
    w_mu: jax.Array,
    w_sigma: jax.Array,
    conn_len: jax.Array,
    seed: jax.Array,
    *,
    shape: Tuple[int, int],
    transpose: bool = False,
    outdim_parallel: bool = True,
) -> jax.Array:
  if outdim_parallel:
    prim = _get_normal_weight_matrix_outdim_parallel_p
  else:
    prim = _get_normal_weight_matrix_p

  return prim(w_mu,
              w_sigma,
              conn_len,
              seed,
              outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)],
              shape=shape,
              transpose=transpose,
              outdim_parallel=outdim_parallel)
@@ -1029,3 +1193,115 @@ def _get_connect_matrix_outdim_parallel(
_get_connect_matrix_p = XLACustomOp(cpu_kernel=_get_connect_matrix, gpu_kernel=_get_connect_matrix)
_get_connect_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_connect_matrix_outdim_parallel,
gpu_kernel=_get_connect_matrix_outdim_parallel)


@ti.kernel
def _get_uniform_weight_matrix(
    w_low: ti.types.ndarray(),
    w_high: ti.types.ndarray(),
    clen: ti.types.ndarray(),
    seed: ti.types.ndarray(),
    out: ti.types.ndarray(),
):
  num_row = out.shape[0]
  num_col = out.shape[1]
  w_low0 = w_low[0]
  w_high0 = w_high[0]
  clen0 = clen[0]
  seed0 = seed[0]

  for i_col in range(num_col):
    key = lfsr88_key(seed0 + i_col)
    key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
    while i_row < num_row:
      key, raw_v = lfsr88_uniform(key, w_low0, w_high0)
      out[i_row, i_col] += raw_v
      key, inc = lfsr88_random_integers(key, 1, clen0)
      i_row += inc


@ti.kernel
def _get_uniform_weight_matrix_outdim_parallel(
    w_low: ti.types.ndarray(),
    w_high: ti.types.ndarray(),
    clen: ti.types.ndarray(),
    seed: ti.types.ndarray(),
    out: ti.types.ndarray(),
):
  num_row = out.shape[0]
  num_col = out.shape[1]
  w_low0 = w_low[0]
  w_high0 = w_high[0]
  clen0 = clen[0]
  seed0 = seed[0]

  for i_row in range(num_row):
    key = lfsr88_key(seed0 + i_row)
    key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
    while i_col < num_col:
      key, raw_v = lfsr88_uniform(key, w_low0, w_high0)
      out[i_row, i_col] += raw_v
      key, inc = lfsr88_random_integers(key, 1, clen0)
      i_col += inc


_get_uniform_weight_matrix_p = XLACustomOp(cpu_kernel=_get_uniform_weight_matrix,
gpu_kernel=_get_uniform_weight_matrix)
_get_uniform_weight_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_uniform_weight_matrix_outdim_parallel,
gpu_kernel=_get_uniform_weight_matrix_outdim_parallel)
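The two kernels above differ only in which axis gets an independent random stream: columns in the first, rows in the outdim-parallel variant. A plain-NumPy re-implementation of the skip-ahead sampling, my addition; `numpy.random` stands in for the LFSR88 generator, so the values differ but the structure is the same:

import numpy as np

def uniform_weight_matrix_ref(w_low, w_high, conn_prob, seed, shape):
  num_row, num_col = shape
  clen = int(np.ceil(1 / conn_prob) * 2 - 1)   # mean jump (1 + clen) / 2 = ceil(1 / conn_prob)
  out = np.zeros(shape, dtype=np.float32)
  for i_col in range(num_col):
    rng = np.random.default_rng(seed + i_col)  # one independent stream per column
    i_row = rng.integers(0, clen)              # first sampled slot in [0, clen - 1]
    while i_row < num_row:
      out[i_row, i_col] += rng.uniform(w_low, w_high)
      i_row += rng.integers(1, clen + 1)       # jump 1..clen slots ahead
  return out

M = uniform_weight_matrix_ref(0.0, 1.0, 0.25, seed=3, shape=(1000, 50))
print((M != 0).mean())  # roughly 0.25: jumps averaging 1/conn_prob give density conn_prob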


@ti.kernel
def _get_normal_weight_matrix(
    w_mu: ti.types.ndarray(),
    w_sigma: ti.types.ndarray(),
    clen: ti.types.ndarray(),
    seed: ti.types.ndarray(),
    out: ti.types.ndarray(),
):
  num_row = out.shape[0]
  num_col = out.shape[1]
  w_mu0 = w_mu[0]
  w_sigma0 = w_sigma[0]
  clen0 = clen[0]
  seed0 = seed[0]

  for i_col in range(num_col):
    key = lfsr88_key(seed0 + i_col)
    key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
    while i_row < num_row:
      key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0)
      out[i_row, i_col] += raw_v
      key, inc = lfsr88_random_integers(key, 1, clen0)
      i_row += inc


@ti.kernel
def _get_normal_weight_matrix_outdim_parallel(
    w_mu: ti.types.ndarray(),
    w_sigma: ti.types.ndarray(),
    clen: ti.types.ndarray(),
    seed: ti.types.ndarray(),
    out: ti.types.ndarray(),
):
  num_row = out.shape[0]
  num_col = out.shape[1]
  w_mu0 = w_mu[0]
  w_sigma0 = w_sigma[0]
  clen0 = clen[0]
  seed0 = seed[0]

  for i_row in range(num_row):
    key = lfsr88_key(seed0 + i_row)
    key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
    while i_col < num_col:
      key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0)
      out[i_row, i_col] += raw_v
      key, inc = lfsr88_random_integers(key, 1, clen0)
      i_col += inc


_get_normal_weight_matrix_p = XLACustomOp(cpu_kernel=_get_normal_weight_matrix,
gpu_kernel=_get_normal_weight_matrix)
_get_normal_weight_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_normal_weight_matrix_outdim_parallel,
gpu_kernel=_get_normal_weight_matrix_outdim_parallel)
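Because the event-driven layers now inherit from the same helper classes, they get the new getters for free. A final sketch, my addition; the constructor arguments assume the existing `EventJitFPNormalLinear(num_in, num_out, prob, w_mu, w_sigma, seed)` signature:

import brainpy as bp

layer = bp.dnn.EventJitFPNormalLinear(100, 200, prob=0.05, w_mu=0., w_sigma=0.02, seed=0)
M = layer.get_conn_matrix()  # dense normally distributed weights of the event-driven layer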