[math] Add getting JIT connect matrix method for brainpy.dnn.linear #672

Merged · 3 commits · May 29, 2024
20 changes: 14 additions & 6 deletions brainpy/_src/dnn/linear.py
@@ -980,7 +980,15 @@ def __init__(
self.sharding = sharding


class JitFPHomoLinear(Layer):
class JitFPLinear(Layer):
def get_conn_matrix(self):
return bm.jitconn.get_conn_matrix(self.prob, self.seed,
shape=(self.num_out, self.num_in),
transpose=self.transpose,
outdim_parallel=not self.atomic)


class JitFPHomoLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
@@ -1059,7 +1067,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


class JitFPUniformLinear(Layer):
class JitFPUniformLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
@@ -1139,7 +1147,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


class JitFPNormalLinear(Layer):
class JitFPNormalLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
@@ -1219,7 +1227,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


class EventJitFPHomoLinear(Layer):
class EventJitFPHomoLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
@@ -1298,7 +1306,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


class EventJitFPUniformLinear(Layer):
class EventJitFPUniformLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
@@ -1378,7 +1386,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


class EventJitFPNormalLinear(Layer):
class EventJitFPNormalLinear(JitFPLinear):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
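With the shared JitFPLinear base class introduced above, every JIT-connectivity layer (JitFPHomoLinear, JitFPUniformLinear, JitFPNormalLinear and the event-driven variants) inherits a get_conn_matrix() method that materializes the otherwise implicit connectivity. A minimal usage sketch; the constructor arguments (num_in, num_out, prob, weight, seed) are assumptions based on the surrounding code, not confirmed by this diff:

import brainpy as bp

# Hypothetical layer configuration; argument names are assumed for illustration.
layer = bp.dnn.JitFPHomoLinear(num_in=200, num_out=100, prob=0.1, weight=1.0, seed=1234)

# Materialize the just-in-time connectivity as a boolean matrix.
# With the defaults shown in the diff (transpose=False), the shape is (num_out, num_in).
conn = layer.get_conn_matrix()
print(conn.shape, conn.dtype)  # expected: (100, 200) bool
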
111 changes: 111 additions & 0 deletions brainpy/_src/math/jitconn/matvec.py
@@ -20,6 +20,7 @@
'mv_prob_homo',
'mv_prob_uniform',
'mv_prob_normal',
'get_conn_matrix',
]


@@ -257,6 +258,49 @@ def mv_prob_normal(
transpose=transpose, outdim_parallel=outdim_parallel)[0]


def get_conn_matrix(
conn_prob: float,
seed: Optional[int] = None,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
r"""Get the connection matrix :math:`M` with a connection probability `conn_prob`.

Parameters
----------
conn_prob: float
The connection probability.
shape: tuple of int
The matrix shape.
seed: int
The random number generation seed.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
Perform the parallel random generations along the out dimension or not.
It can be used to ensure that the just-in-time generated :math:`M^T` is the same
as the just-in-time generated :math:`M` when ``transpose=True``.

Returns
-------
out: Array, ndarray
The connection matrix :math:`M`.
"""
if ti is None:
raise PackageMissingError.by_purpose('taichi', purpose='customized operators')

conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
if seed is None:
with jax.ensure_compile_time_eval():
seed = np.random.randint(0, int(1e8), 1)
seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
return raw_get_connect_matrix(conn_len, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)


def raw_mv_prob_homo(
vector: jax.Array,
weight: jax.Array, # vector with size 1
@@ -342,6 +386,28 @@ def raw_mv_prob_normal(
outdim_parallel=outdim_parallel)


def raw_get_connect_matrix(
conn_len: jax.Array,
seed: jax.Array,
*,
shape: Tuple[int, int],
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
out_shape = shape if not transpose else (shape[1], shape[0])
if outdim_parallel:
prim = _get_connect_matrix_outdim_parallel_p
else:
prim = _get_connect_matrix_p

return prim(conn_len,
seed,
outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=jnp.int32)],
shape=shape,
transpose=transpose,
outdim_parallel=outdim_parallel)


def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights):
if vector.ndim != 1:
raise ValueError('vector should be a 1D vector.')
@@ -918,3 +984,48 @@ def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel):
cpu_kernel=_mv_prob_normal_cpu,
gpu_kernel=_mv_prob_normal_gpu
)


@ti.kernel
def _get_connect_matrix(
clen: ti.types.ndarray(),
seed: ti.types.ndarray(),
out: ti.types.ndarray(),
):
num_row = out.shape[0]
num_col = out.shape[1]
clen0 = clen[0]
seed0 = seed[0]

for i_col in range(num_col):
key = lfsr88_key(seed0 + i_col)
key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
while i_row < num_row:
out[i_row, i_col] = 1
key, inc = lfsr88_random_integers(key, 1, clen0)
i_row += inc


@ti.kernel
def _get_connect_matrix_outdim_parallel(
clen: ti.types.ndarray(),
seed: ti.types.ndarray(),
out: ti.types.ndarray(),
):
num_row = out.shape[0]
num_col = out.shape[1]
clen0 = clen[0]
seed0 = seed[0]

for i_row in range(num_row):
key = lfsr88_key(seed0 + i_row)
key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
while i_col < num_col:
out[i_row, i_col] = 1
key, inc = lfsr88_random_integers(key, 1, clen0)
i_col += inc


_get_connect_matrix_p = XLACustomOp(cpu_kernel=_get_connect_matrix, gpu_kernel=_get_connect_matrix)
_get_connect_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_connect_matrix_outdim_parallel,
gpu_kernel=_get_connect_matrix_outdim_parallel)
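
The two Taichi kernels above regenerate the connectivity with a skip-ahead scheme instead of storing it: for each column (or each row, in the outdim_parallel variant) an LFSR88 stream seeded by seed + index draws a starting offset in [0, clen - 1] and then repeatedly jumps forward by a uniform step in [1, clen]. Because conn_len = ceil(1 / conn_prob) * 2 - 1, the mean jump is about 1 / conn_prob, so each entry ends up connected with probability close to conn_prob. Below is a plain-Python sketch of the same idea; it swaps LFSR88 for NumPy's generator purely for illustration, so it shows the scheme but will not reproduce the kernels' exact pattern:

import numpy as np

def sketch_conn_matrix(prob, seed, shape):
    # Illustrative re-implementation of the skip-ahead sampling (not bit-exact with the Taichi kernels).
    num_row, num_col = shape
    clen = int(np.ceil(1 / prob) * 2 - 1)          # mean jump (1 + clen) / 2 ~= 1 / prob
    out = np.zeros(shape, dtype=bool)
    for i_col in range(num_col):
        rng = np.random.default_rng(seed + i_col)  # one independent stream per column
        i_row = int(rng.integers(0, clen))         # start in [0, clen - 1]
        while i_row < num_row:
            out[i_row, i_col] = True
            i_row += int(rng.integers(1, clen + 1))  # jump in [1, clen]
    return out

m = sketch_conn_matrix(0.1, seed=1234, shape=(100, 200))
print(m.mean())  # close to 0.1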
46 changes: 46 additions & 0 deletions brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
import jax.numpy as jnp
import pytest
from absl.testing import parameterized

import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

import platform

force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

shapes = [(100, 200), (1000, 10)]


# SEED = 1234

class TestGetConnectMatrix(parameterized.TestCase):
def __init__(self, *args, platform='cpu', **kwargs):
super(TestGetConnectMatrix, self).__init__(*args, **kwargs)
bm.set_platform(platform)
print()

@parameterized.product(
transpose=[True, False],
outdim_parallel=[True, False],
shape=shapes,
prob=[0.1],
)
def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob):
print(
f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}')
conn = bm.jitconn.get_conn_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
shape = (shape[1], shape[0]) if transpose else shape
assert conn.shape == shape
assert conn.dtype == jnp.bool_
# sum all true values
# assert jnp.sum(conn) == jnp.round(prob * shape[0] * shape[1])
print(
f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')
# print(f'conn: {conn}')
13 changes: 8 additions & 5 deletions brainpy/_src/math/op_register/base.py
@@ -13,14 +13,16 @@
from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
from .cupy_based import (
register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
else:
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
from .cupy_based import (
register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

@@ -116,7 +118,8 @@ def __init__(
register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
gpu_checked = True
if not gpu_checked:
raise ValueError(f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}')
raise ValueError(
f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}')

# batching rule
if batching_translation is None:
2 changes: 2 additions & 0 deletions brainpy/math/jitconn.py
@@ -6,5 +6,7 @@
mv_prob_homo as mv_prob_homo,
mv_prob_uniform as mv_prob_uniform,
mv_prob_normal as mv_prob_normal,

get_conn_matrix as get_conn_matrix,
)
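
With this export, the matrix can also be obtained directly from brainpy.math.jitconn, as the test above already does. A minimal check, assuming the defaults transpose=False and outdim_parallel=True shown in the diff; since the generation is driven entirely by the seed, the same arguments should reproduce the same matrix:

import jax.numpy as jnp
import brainpy.math as bm

conn1 = bm.jitconn.get_conn_matrix(0.1, seed=1234, shape=(100, 200))
conn2 = bm.jitconn.get_conn_matrix(0.1, seed=1234, shape=(100, 200))
assert conn1.shape == (100, 200) and conn1.dtype == jnp.bool_
assert bool(jnp.all(conn1 == conn2))   # fixed seed -> reproducible connectivity
print(float(jnp.mean(conn1)))          # roughly 0.1 of the entries are True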
