Commit

[math] Add get JIT connect matrix methods for brainpy.dnn.linear

Routhleck committed May 27, 2024
1 parent e3a854a commit dd084a7
Showing 5 changed files with 203 additions and 5 deletions.
36 changes: 36 additions & 0 deletions brainpy/_src/dnn/linear.py
@@ -1058,6 +1058,12 @@ def _batch_mv(self, x):
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)

  def get_connect_matrix(self):
    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
                                         shape=(self.num_out, self.num_in),
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)


class JitFPUniformLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
@@ -1138,6 +1144,12 @@ def _batch_mv(self, x):
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)

  def get_connect_matrix(self):
    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
                                         shape=(self.num_out, self.num_in),
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)


class JitFPNormalLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
@@ -1218,6 +1230,12 @@ def _batch_mv(self, x):
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)

  def get_connect_matrix(self):
    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
                                         shape=(self.num_out, self.num_in),
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)


class EventJitFPHomoLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
@@ -1297,6 +1315,12 @@ def _batch_mv(self, x):
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)

  def get_connect_matrix(self):
    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
                                         shape=(self.num_out, self.num_in),
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)


class EventJitFPUniformLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
@@ -1377,6 +1401,12 @@ def _batch_mv(self, x):
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)

  def get_connect_matrix(self):
    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
                                         shape=(self.num_out, self.num_in),
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)


class EventJitFPNormalLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
@@ -1456,3 +1486,9 @@ def _batch_mv(self, x):
                                         shape=(self.num_out, self.num_in),
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)

  def get_connect_matrix(self):
    return bm.jitconn.get_connect_matrix(self.prob, self.seed,
                                         shape=(self.num_out, self.num_in),
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)
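
All six JIT layers gain the same accessor. A minimal usage sketch (the constructor arguments below are illustrative assumptions, not taken from this diff):

import brainpy as bp

# Hypothetical layer sizes and probability, for illustration only.
layer = bp.dnn.JitFPHomoLinear(num_in=100, num_out=200, prob=0.1, weight=0.01, seed=42)
conn = layer.get_connect_matrix()  # boolean mask of the JIT-generated connectivity
print(conn.shape, conn.dtype)      # (200, 100) and bool when transpose is False
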
111 changes: 111 additions & 0 deletions brainpy/_src/math/jitconn/matvec.py
@@ -20,6 +20,7 @@
  'mv_prob_homo',
  'mv_prob_uniform',
  'mv_prob_normal',
  'get_connect_matrix',
]


@@ -257,6 +258,49 @@ def mv_prob_normal(
                              transpose=transpose, outdim_parallel=outdim_parallel)[0]


def get_connect_matrix(
    conn_prob: float,
    seed: Optional[int] = None,
    *,
    shape: Tuple[int, int],
    transpose: bool = False,
    outdim_parallel: bool = True,
) -> jax.Array:
  r"""Get the connection matrix :math:`M` with a connection probability `conn_prob`.

  Parameters
  ----------
  conn_prob: float
    The connection probability.
  seed: int
    The random number generation seed.
  shape: tuple of int
    The matrix shape.
  transpose: bool
    Transpose the random matrix or not.
  outdim_parallel: bool
    Perform the parallel random generations along the out dimension or not.
    It can be used to ensure that the just-in-time generated :math:`M^T` is
    the same as the just-in-time generated :math:`M` when ``transpose=True``.

  Returns
  -------
  out: Array, ndarray
    The connection matrix :math:`M`.
  """
  if ti is None:
    raise PackageMissingError.by_purpose('taichi', purpose='customized operators')

  # clen = ceil(1 / conn_prob) * 2 - 1, so a random jump drawn uniformly from
  # [1, clen] has mean 1 / conn_prob, which reproduces the target density.
  conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
  conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
  if seed is None:
    with jax.ensure_compile_time_eval():
      seed = np.random.randint(0, int(1e8), 1)
  seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
  return raw_get_connect_matrix(conn_len, seed, shape=shape,
                                transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
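
A quick sketch of the functional API (values are arbitrary; fixing the seed makes the generated connectivity fully reproducible):

import brainpy.math as bm
import jax.numpy as jnp

conn = bm.jitconn.get_connect_matrix(0.1, seed=1234, shape=(200, 100))
assert conn.shape == (200, 100)  # transpose=False keeps the given shape
assert conn.dtype == jnp.bool_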


def raw_mv_prob_homo(
    vector: jax.Array,
    weight: jax.Array,  # vector with size 1
@@ -342,6 +386,28 @@ def raw_mv_prob_normal(
outdim_parallel=outdim_parallel)


def raw_get_connect_matrix(
    conn_len: jax.Array,
    seed: jax.Array,
    *,
    shape: Tuple[int, int],
    transpose: bool = False,
    outdim_parallel: bool = True,
) -> jax.Array:
  out_shape = shape if not transpose else (shape[1], shape[0])
  if outdim_parallel:
    prim = _get_connect_matrix_outdim_parallel_p
  else:
    prim = _get_connect_matrix_p

  return prim(conn_len,
              seed,
              outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=jnp.int32)],
              shape=shape,
              transpose=transpose,
              outdim_parallel=outdim_parallel)


def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights):
  if vector.ndim != 1:
    raise ValueError('vector should be a 1D vector.')
@@ -918,3 +984,48 @@ def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel):
  cpu_kernel=_mv_prob_normal_cpu,
  gpu_kernel=_mv_prob_normal_gpu
)


@ti.kernel
def _get_connect_matrix(
    clen: ti.types.ndarray(),
    seed: ti.types.ndarray(),
    out: ti.types.ndarray(),
):
  num_row = out.shape[0]
  num_col = out.shape[1]
  clen0 = clen[0]
  seed0 = seed[0]

  # Parallelize over columns; each column derives its own LFSR88 stream from
  # the seed, so the same seed always regenerates the same connectivity.
  for i_col in range(num_col):
    key = lfsr88_key(seed0 + i_col)
    key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
    while i_row < num_row:
      out[i_row, i_col] = 1
      # Jump ahead by a random gap in [1, clen0]; the mean gap is 1 / conn_prob,
      # which yields the requested connection density.
      key, inc = lfsr88_random_integers(key, 1, clen0)
      i_row += inc


@ti.kernel
def _get_connect_matrix_outdim_parallel(
    clen: ti.types.ndarray(),
    seed: ti.types.ndarray(),
    out: ti.types.ndarray(),
):
  num_row = out.shape[0]
  num_col = out.shape[1]
  clen0 = clen[0]
  seed0 = seed[0]

  # Same skip-ahead sampling, but parallelized over rows (the out dimension).
  for i_row in range(num_row):
    key = lfsr88_key(seed0 + i_row)
    key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
    while i_col < num_col:
      out[i_row, i_col] = 1
      key, inc = lfsr88_random_integers(key, 1, clen0)
      i_col += inc


_get_connect_matrix_p = XLACustomOp(cpu_kernel=_get_connect_matrix, gpu_kernel=_get_connect_matrix)
_get_connect_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_connect_matrix_outdim_parallel,
                                                    gpu_kernel=_get_connect_matrix_outdim_parallel)
46 changes: 46 additions & 0 deletions brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
import jax.numpy as jnp
import pytest
from absl.testing import parameterized

import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
  pytest.skip('no taichi', allow_module_level=True)

import platform

force_test = False  # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
  pytest.skip('skip windows', allow_module_level=True)

shapes = [(100, 200), (1000, 10)]


# SEED = 1234

class TestGetConnectMatrix(parameterized.TestCase):
  def __init__(self, *args, platform='cpu', **kwargs):
    super(TestGetConnectMatrix, self).__init__(*args, **kwargs)
    bm.set_platform(platform)
    print()

  @parameterized.product(
    transpose=[True, False],
    outdim_parallel=[True, False],
    shape=shapes,
    prob=[0.1],
  )
  def test_get_connect_matrix(self, transpose, outdim_parallel, shape, prob):
    print(f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, '
          f'shape={shape}, prob={prob}')
    conn = bm.jitconn.get_connect_matrix(prob, shape=shape, transpose=transpose,
                                         outdim_parallel=outdim_parallel)
    shape = (shape[1], shape[0]) if transpose else shape
    assert conn.shape == shape
    assert conn.dtype == jnp.bool_
    # The True count only approximates prob * size (the generator draws random
    # gaps), so it is printed for inspection rather than asserted exactly.
    # assert jnp.sum(conn) == jnp.round(prob * shape[0] * shape[1])
    print(f'jnp.sum(conn): {jnp.sum(conn)}, '
          f'jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')
    # print(f'conn: {conn}')
13 changes: 8 additions & 5 deletions brainpy/_src/math/op_register/base.py
@@ -13,14 +13,16 @@
  from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
  from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                 register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
  from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
                           register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
  from .cupy_based import (
    register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
    register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
else:
  from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
  from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                 register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
  from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
                           register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
  from .cupy_based import (
    register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
    register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

@@ -116,7 +118,8 @@ def __init__(
      register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
      gpu_checked = True
    if not gpu_checked:
      raise ValueError(f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}')
      raise ValueError(
        f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}')

    # batching rule
    if batching_translation is None:
2 changes: 2 additions & 0 deletions brainpy/math/jitconn.py
@@ -6,5 +6,7 @@
  mv_prob_homo as mv_prob_homo,
  mv_prob_uniform as mv_prob_uniform,
  mv_prob_normal as mv_prob_normal,

  get_connect_matrix as get_connect_matrix,
)
