Skip to content

Commit

Permalink
fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Feb 29, 2024
2 parents 1d83b77 + 6165a95 commit cbfc66e
Show file tree
Hide file tree
Showing 41 changed files with 1,651 additions and 1,670 deletions.
285 changes: 71 additions & 214 deletions brainpy/_src/connect/random_conn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-

from functools import partial
from typing import Optional

Expand All @@ -9,10 +10,8 @@
from brainpy.errors import ConnectorError
from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed
from brainpy._src.tools.package import SUPPORT_NUMBA
from brainpy._src.dependency_check import import_numba
from .base import *

numba = import_numba(error_if_not_found=False)

__all__ = [
'FixedProb',
Expand Down Expand Up @@ -1099,192 +1098,69 @@ def __init__(self, dist=1, prob=1., pre_ratio=1., seed=None, include_self=True,

rng = np.random if SUPPORT_NUMBA else self.rng

# @njit(parallel=True)
# def _connect_1d_jit_parallel(pre_pos, pre_size, post_size, n_dim):
# all_post_ids = np.zeros(post_size[0], dtype=get_idx_type())
# all_pre_ids = np.zeros(post_size[0], dtype=get_idx_type())
# size = 0
#
# if rng.random() < pre_ratio:
# normalized_pos = np.zeros(n_dim)
# for i in prange(n_dim): # Use prange for potential parallelism
# pre_len = pre_size[i]
# post_len = post_size[i]
# normalized_pos[i] = pre_pos[i] * post_len / pre_len
# for i in prange(post_size[0]):
# post_pos = np.asarray((i,))
# d = np.abs(pre_pos[0] - post_pos[0]) # Adjust the distance calculation
# if d <= dist:
# if d == 0. and not include_self:
# continue
# if rng.random() <= prob:
# all_post_ids[size] = pos2ind(post_pos, post_size)
# all_pre_ids[size] = pos2ind(pre_pos, pre_size)
# size += 1
# return all_pre_ids[:size], all_post_ids[:size] # Return filled part of the arrays

if numba is not None:
from numba import njit
@njit
def _connect_1d_jit(pre_pos, pre_size, post_size, n_dim):
all_post_ids = np.zeros(post_size[0], dtype=IDX_DTYPE)
all_pre_ids = np.zeros(post_size[0], dtype=IDX_DTYPE)
size = 0

if rng.random() < pre_ratio:
normalized_pos = np.zeros(n_dim)
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos[i] = pre_pos[i] * post_len / pre_len
for i in range(post_size[0]):
post_pos = np.asarray((i,))
d = np.abs(pre_pos[0] - post_pos[0])
if d <= dist:
if d == 0. and not include_self:
continue
if rng.random() <= prob:
all_post_ids[size] = pos2ind(post_pos, post_size)
all_pre_ids[size] = pos2ind(pre_pos, pre_size)
size += 1
return all_pre_ids[:size], all_post_ids[:size]

@njit
def _connect_2d_jit(pre_pos, pre_size, post_size, n_dim):
max_size = post_size[0] * post_size[1]
all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE)
all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE)
size = 0

if rng.random() < pre_ratio:
normalized_pos = np.zeros(n_dim)
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos[i] = pre_pos[i] * post_len / pre_len
for i in range(post_size[0]):
for j in range(post_size[1]):
post_pos = np.asarray((i, j))
d = np.sqrt(np.sum(np.square(pre_pos - post_pos)))
if d <= dist:
if d == 0. and not include_self:
continue
if rng.random() <= prob:
all_post_ids[size] = pos2ind(post_pos, post_size)
all_pre_ids[size] = pos2ind(pre_pos, pre_size)
size += 1
return all_pre_ids[:size], all_post_ids[:size] # Return filled part of the arrays

@njit
def _connect_3d_jit(pre_pos, pre_size, post_size, n_dim):
max_size = post_size[0] * post_size[1] * post_size[2]
all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE)
all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE)
size = 0

if rng.random() < pre_ratio:
normalized_pos = np.zeros(n_dim)
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos[i] = pre_pos[i] * post_len / pre_len
for i in range(post_size[0]):
for j in range(post_size[1]):
for k in range(post_size[2]):
post_pos = np.asarray((i, j, k))
d = np.sqrt(np.sum(np.square(pre_pos - post_pos)))
if d <= dist:
if d == 0. and not include_self:
continue
if rng.random() <= prob:
all_post_ids[size] = pos2ind(post_pos, post_size)
all_pre_ids[size] = pos2ind(pre_pos, pre_size)
size += 1
return all_pre_ids[:size], all_post_ids[:size]

@njit
def _connect_4d_jit(pre_pos, pre_size, post_size, n_dim):
max_size = post_size[0] * post_size[1] * post_size[2] * post_size[3]
all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE)
all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE)
size = 0

if rng.random() < pre_ratio:
normalized_pos = np.zeros(n_dim)
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos[i] = pre_pos[i] * post_len / pre_len
for i in range(post_size[0]):
for j in range(post_size[1]):
for k in range(post_size[2]):
for l in range(post_size[3]):
post_pos = np.asarray((i, j, k, l))
d = np.sqrt(np.sum(np.square(pre_pos - post_pos)))
if d <= dist:
if d == 0. and not include_self:
continue
if rng.random() <= prob:
all_post_ids[size] = pos2ind(post_pos, post_size)
all_pre_ids[size] = pos2ind(pre_pos, pre_size)
size += 1
return all_pre_ids[:size], all_post_ids[:size]

self._connect_1d_jit = _connect_1d_jit
self._connect_2d_jit = _connect_2d_jit
self._connect_3d_jit = _connect_3d_jit
self._connect_4d_jit = _connect_4d_jit

def _connect_1d(pre_pos, pre_size, post_size, n_dim):
all_post_ids = []
all_pre_ids = []
@numba_jit
def _connect_1d_jit(pre_pos, pre_size, post_size, n_dim):
all_post_ids = np.zeros(post_size[0], dtype=IDX_DTYPE)
all_pre_ids = np.zeros(post_size[0], dtype=IDX_DTYPE)
size = 0

if rng.random() < pre_ratio:
normalized_pos = []
normalized_pos = np.zeros(n_dim)
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos.append(pre_pos[i] * post_len / pre_len)
normalized_pos[i] = pre_pos[i] * post_len / pre_len
for i in range(post_size[0]):
post_pos = np.asarray((i,))
d = np.sum(np.abs(pre_pos - post_pos))
d = np.abs(pre_pos[0] - post_pos[0])
if d <= dist:
if d == 0. and not include_self:
continue
if rng.random() <= prob:
all_post_ids.append(pos2ind(post_pos, post_size))
all_pre_ids.append(pos2ind(pre_pos, pre_size))
return all_pre_ids, all_post_ids
all_post_ids[size] = pos2ind(post_pos, post_size)
all_pre_ids[size] = pos2ind(pre_pos, pre_size)
size += 1
return all_pre_ids[:size], all_post_ids[:size]

@numba_jit
def _connect_2d_jit(pre_pos, pre_size, post_size, n_dim):
max_size = post_size[0] * post_size[1]
all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE)
all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE)
size = 0

def _connect_2d(pre_pos, pre_size, post_size, n_dim):
all_post_ids = []
all_pre_ids = []
if rng.random() < pre_ratio:
normalized_pos = []
normalized_pos = np.zeros(n_dim)
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos.append(pre_pos[i] * post_len / pre_len)
normalized_pos[i] = pre_pos[i] * post_len / pre_len
for i in range(post_size[0]):
for j in range(post_size[1]):
post_pos = np.asarray((i, j))
d = np.sqrt(np.sum(np.square(pre_pos - post_pos)))
if d <= dist:
if d == 0. and not include_self:
continue
if np.random.random() <= prob:
all_post_ids.append(pos2ind(post_pos, post_size))
all_pre_ids.append(pos2ind(pre_pos, pre_size))
return all_pre_ids, all_post_ids

def _connect_3d(pre_pos, pre_size, post_size, n_dim):
all_post_ids = []
all_pre_ids = []
if rng.random() <= prob:
all_post_ids[size] = pos2ind(post_pos, post_size)
all_pre_ids[size] = pos2ind(pre_pos, pre_size)
size += 1
return all_pre_ids[:size], all_post_ids[:size] # Return filled part of the arrays

@numba_jit
def _connect_3d_jit(pre_pos, pre_size, post_size, n_dim):
max_size = post_size[0] * post_size[1] * post_size[2]
all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE)
all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE)
size = 0

if rng.random() < pre_ratio:
normalized_pos = []
normalized_pos = np.zeros(n_dim)
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos.append(pre_pos[i] * post_len / pre_len)
normalized_pos[i] = pre_pos[i] * post_len / pre_len
for i in range(post_size[0]):
for j in range(post_size[1]):
for k in range(post_size[2]):
Expand All @@ -1293,20 +1169,25 @@ def _connect_3d(pre_pos, pre_size, post_size, n_dim):
if d <= dist:
if d == 0. and not include_self:
continue
if np.random.random() <= prob:
all_post_ids.append(pos2ind(post_pos, post_size))
all_pre_ids.append(pos2ind(pre_pos, pre_size))
return all_pre_ids, all_post_ids

def _connect_4d(pre_pos, pre_size, post_size, n_dim):
all_post_ids = []
all_pre_ids = []
if rng.random() <= prob:
all_post_ids[size] = pos2ind(post_pos, post_size)
all_pre_ids[size] = pos2ind(pre_pos, pre_size)
size += 1
return all_pre_ids[:size], all_post_ids[:size]

@numba_jit
def _connect_4d_jit(pre_pos, pre_size, post_size, n_dim):
max_size = post_size[0] * post_size[1] * post_size[2] * post_size[3]
all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE)
all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE)
size = 0

if rng.random() < pre_ratio:
normalized_pos = []
normalized_pos = np.zeros(n_dim)
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos.append(pre_pos[i] * post_len / pre_len)
normalized_pos[i] = pre_pos[i] * post_len / pre_len
for i in range(post_size[0]):
for j in range(post_size[1]):
for k in range(post_size[2]):
Expand All @@ -1316,15 +1197,16 @@ def _connect_4d(pre_pos, pre_size, post_size, n_dim):
if d <= dist:
if d == 0. and not include_self:
continue
if np.random.random() <= prob:
all_post_ids.append(pos2ind(post_pos, post_size))
all_pre_ids.append(pos2ind(pre_pos, pre_size))
return all_pre_ids, all_post_ids
if rng.random() <= prob:
all_post_ids[size] = pos2ind(post_pos, post_size)
all_pre_ids[size] = pos2ind(pre_pos, pre_size)
size += 1
return all_pre_ids[:size], all_post_ids[:size]

self._connect_1d = numba_jit(_connect_1d)
self._connect_2d = numba_jit(_connect_2d)
self._connect_3d = numba_jit(_connect_3d)
self._connect_4d = numba_jit(_connect_4d)
self._connect_1d_jit = _connect_1d_jit
self._connect_2d_jit = _connect_2d_jit
self._connect_3d_jit = _connect_3d_jit
self._connect_4d_jit = _connect_4d_jit

def build_coo(self, isOptimized=True):
if len(self.pre_size) != len(self.post_size):
Expand All @@ -1336,41 +1218,16 @@ def build_coo(self, isOptimized=True):

# connections
n_dim = len(self.pre_size)
if not isOptimized:
if n_dim == 1:
f = self._connect_1d
elif n_dim == 2:
f = self._connect_2d
elif n_dim == 3:
f = self._connect_3d
elif n_dim == 4:
f = self._connect_4d
else:
raise NotImplementedError('Does not support the network dimension bigger than 4.')
if n_dim == 1:
f = self._connect_1d_jit
elif n_dim == 2:
f = self._connect_2d_jit
elif n_dim == 3:
f = self._connect_3d_jit
elif n_dim == 4:
f = self._connect_4d_jit
else:
if numba is None:
if n_dim == 1:
f = self._connect_1d
elif n_dim == 2:
f = self._connect_2d
elif n_dim == 3:
f = self._connect_3d
elif n_dim == 4:
f = self._connect_4d
else:
raise NotImplementedError('Does not support the network dimension bigger than 4.')
else:
if n_dim == 1:
f = self._connect_1d_jit
elif n_dim == 2:
f = self._connect_2d_jit
elif n_dim == 3:
f = self._connect_3d_jit
elif n_dim == 4:
f = self._connect_4d_jit
else:
raise NotImplementedError('Does not support the network dimension bigger than 4.')

raise NotImplementedError('Does not support the network dimension bigger than 4.')

pre_size = np.asarray(self.pre_size)
post_size = np.asarray(self.post_size)
Expand Down
Loading

0 comments on commit cbfc66e

Please sign in to comment.