From e17258f01fe824c32e455d8b77f2b0c50e6aa065 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 1 Mar 2024 15:11:47 +0800 Subject: [PATCH] update tests --- .github/workflows/CI.yml | 197 --- brainpy/_src/connect/tests/test_all_time.py | 1272 +++++++++-------- brainpy/_src/math/event/__init__.py | 2 +- .../event/{_csr_matvec.py => csr_matvec.py} | 4 +- .../_src/math/event/tests/test_event_csrmv.py | 100 +- brainpy/_src/math/jitconn/__init__.py | 4 +- .../{_event_matvec.py => event_matvec.py} | 22 +- .../math/jitconn/{_matvec.py => matvec.py} | 0 .../math/jitconn/tests/test_event_matvec.py | 913 ++++++------ .../_src/math/jitconn/tests/test_matvec.py | 748 +++++----- brainpy/_src/math/sparse/__init__.py | 8 +- .../math/sparse/{_bsr_mm.py => bsr_mm.py} | 0 .../math/sparse/{_bsr_mv.py => bsr_mv.py} | 2 +- .../math/sparse/{_coo_mv.py => coo_mv.py} | 0 .../math/sparse/{_csr_mv.py => csr_mv.py} | 2 +- .../math/sparse/{_jax_prim.py => jax_prim.py} | 0 brainpy/_src/math/sparse/tests/test_csrmv.py | 33 +- .../_src/math/sparse/{_utils.py => utils.py} | 0 18 files changed, 1444 insertions(+), 1863 deletions(-) rename brainpy/_src/math/event/{_csr_matvec.py => csr_matvec.py} (99%) rename brainpy/_src/math/jitconn/{_event_matvec.py => event_matvec.py} (98%) rename brainpy/_src/math/jitconn/{_matvec.py => matvec.py} (100%) rename brainpy/_src/math/sparse/{_bsr_mm.py => bsr_mm.py} (100%) rename brainpy/_src/math/sparse/{_bsr_mv.py => bsr_mv.py} (99%) rename brainpy/_src/math/sparse/{_coo_mv.py => coo_mv.py} (100%) rename brainpy/_src/math/sparse/{_csr_mv.py => csr_mv.py} (99%) rename brainpy/_src/math/sparse/{_jax_prim.py => jax_prim.py} (100%) rename brainpy/_src/math/sparse/{_utils.py => utils.py} (100%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 95bd8eafd..66af94fce 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -50,71 +50,6 @@ jobs: cd brainpy pytest _src/ - test_linux_with_taichi_numba: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [ "3.9", "3.10", "3.11"] - - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest taichi numba - if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi - pip uninstall brainpy -y - python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - cd brainpy - pytest _src/ - - -# test_linux_py37: -# runs-on: ubuntu-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.7"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi -# pip install jax==0.3.25 -# pip install jaxlib==0.3.25 -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ -# test_macos: runs-on: macos-latest @@ -147,135 +82,3 @@ jobs: cd brainpy pytest _src/ - test_macos_with_taichi_numba: - runs-on: macos-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.9", "3.10", "3.11"] - - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest taichi numba - if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi - pip uninstall brainpy -y - python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - cd brainpy - pytest _src/ - -# test_macos_py37: -# runs-on: macos-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: [ "3.7" ] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi -# pip install jax==0.3.25 -# pip install jaxlib==0.3.25 -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ -# - - -# test_windows: -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.9", "3.10", "3.11"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# python -m pip install -r requirements-dev.txt -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd brainpy -# pytest _src/ - - -# test_windows_py37: -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.7"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# python -m pip install numpy>=1.21.0 -# python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver -# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz -# python -m pip install -r requirements-dev.txt -# python -m pip install tqdm brainpylib -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py index b634d6dbe..07422f65e 100644 --- a/brainpy/_src/connect/tests/test_all_time.py +++ b/brainpy/_src/connect/tests/test_all_time.py @@ -1,18 +1,19 @@ import time from datetime import datetime -import brainpy as bp -import unittest import pytest +import brainpy as bp + +pytest.skip('skip.', allow_module_level=True) + try: - import pandas as pd + import pandas as pd - df = pd.DataFrame( - columns=['connector name', 'superclass', 'connect matrix size', 'build function', 'other parameter', - 'time(ms)']) + df = pd.DataFrame(columns=['connector name', 'superclass', 'connect matrix size', + 'build function', 'other parameter', 'time(ms)']) except (ImportError, ModuleNotFoundError): - print('No pandas installed, skip test.') + print('No pandas installed, skip test.') # size_same = [100, 500, 2500, 12500, 25000, 37500, 50000] # size_same = [100, 500, 2500, 12500] @@ -21,644 +22,645 @@ size_same = [100, 500, 2500] size_diff = [(10, 100), (100, 1000)] + def get_ms(value): - return round(value * 1000, 4) + return round(value * 1000, 4) def insert_row(connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used): - try: - df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used] - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + try: + df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used] + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') class OneEndConnector(unittest.TestCase): - def test_gaussian_prob(self): - print() - for size in size_same: - print('GaussianProb:', size) - conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GaussianProb', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'sigma=1/include_self=False', - time_used) - - # start = time.time() - # conn.require(bp.connect.COO) - # time_used = get_ms(time.time() - start) - # df.loc[len(df)] = ['GaussianProb', - # 'OneEndConnector', - # f'{size}x{size}', - # 'build_coo', - # 'sigma=1/include_self=False', - # time_used] - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GaussianProb', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'sigma=1/include_self=False', - time_used) - - def test_grid_four(self): - print() - for size in size_same: - print('GridFour:', size) - conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = 
get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False', - time_used) - - def test_grid_eight(self): - print() - for size in size_same: - print('GridEight:', size) - conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False', - time_used) - - def test_grid_n(self): - print() - for size in size_same: - print('GridN:', size) - conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False/N=2', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False/N=2', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False/N=2', - time_used) + def test_gaussian_prob(self): + print() + for size in size_same: + print('GaussianProb:', size) + conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GaussianProb', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'sigma=1/include_self=False', + time_used) + + # start = time.time() + # conn.require(bp.connect.COO) + # time_used = get_ms(time.time() - start) + # df.loc[len(df)] = ['GaussianProb', + # 'OneEndConnector', + # f'{size}x{size}', + # 'build_coo', + # 'sigma=1/include_self=False', + # time_used] + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GaussianProb', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'sigma=1/include_self=False', + time_used) + + def test_grid_four(self): + print() + for size in size_same: + print('GridFour:', size) + conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False', + time_used) + + start = 
time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False', + time_used) + + def test_grid_eight(self): + print() + for size in size_same: + print('GridEight:', size) + conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False', + time_used) + + def test_grid_n(self): + print() + for size in size_same: + print('GridN:', size) + conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False/N=2', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False/N=2', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False/N=2', + time_used) class TwoEndConnector(unittest.TestCase): - def test_fixed_prob(self): - print() - for size in size_same: - print('FixedProb:', size) - conn = bp.connect.FixedProb(prob=0.1, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'prob=0.1', - time_used) - - for size in size_diff: - print('FixedProb:', size) - conn = bp.connect.FixedProb(prob=0.1, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - 
insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'prob=0.1', - time_used) - - def test_fixed_pre_num(self): - print() - for size in size_same: - print('FixedPreNum:', size) - conn = bp.connect.FixedPreNum(num=0.4, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'pre_num=10', - time_used) - - for size in size_diff: - print('FixedPreNum:', size) - conn = bp.connect.FixedPreNum(num=0.4, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'pre_num=10', - time_used) - - def test_fixed_post_num(self): - print() - for size in size_same: - print('FixedPostNum:', size) - conn = bp.connect.FixedPostNum(num=10, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - mat = conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'num=10', - time_used) - - for size in size_diff: - print('FixedPostNum:', size) - conn = bp.connect.FixedPreNum(num=10, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'pre_num=10', - time_used) - - def test_prob_dist(self): - print() - for size in size_same: - print('ProbDist:', size) - conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True) - 
conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'prob=0.5', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', - time_used) - - def test_small_world(self): - print() - for size in size_same: - print('SmallWorld:', size) - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - def test_scale_free_ba(self): - print() - for size in size_same: - print('ScaleFreeBA:', size) - conn = bp.connect.ScaleFreeBA(m=2) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm=2', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm=2', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm=2', - time_used) - - def test_scale_free_ba_dual(self): - print() - for size in size_same: - print('ScaleFreeBADual:', size) - conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm1=2/m2=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm1=2/m2=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm1=2/m2=3/p=0.4', - time_used) - - def test_power_law(self): - print() - for size in size_same: - print('PowerLaw:', size) - conn = bp.connect.PowerLaw(m=3, p=0.4) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) 
- time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm=3/p=0.4', - time_used) - - def test_one2one(self): - print() - for size in size_same: - print('One2One:', size) - conn = bp.connect.One2One() - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - '', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - '', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - '', - time_used) - - def test_all2all(self): - print() - for size in size_same: - print('All2All:', size) - conn = bp.connect.All2All() - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('All2All', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - '', - time_used) - - # start = time.time() - # conn.require(bp.connect.COO) - # time_used = get_ms(time.time() - start) - # df.loc[len(df)] = ['All2All', - # 'TwoEndConnector', - # f'{size}x{size}', - # 'build_coo', - # '', - # time_used] - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('All2All', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - '', - time_used) + def test_fixed_prob(self): + print() + for size in size_same: + print('FixedProb:', size) + conn = bp.connect.FixedProb(prob=0.1, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'prob=0.1', + time_used) + + for size in size_diff: + print('FixedProb:', size) + conn = bp.connect.FixedProb(prob=0.1, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = 
get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'prob=0.1', + time_used) + + def test_fixed_pre_num(self): + print() + for size in size_same: + print('FixedPreNum:', size) + conn = bp.connect.FixedPreNum(num=0.4, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'pre_num=10', + time_used) + + for size in size_diff: + print('FixedPreNum:', size) + conn = bp.connect.FixedPreNum(num=0.4, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'pre_num=10', + time_used) + + def test_fixed_post_num(self): + print() + for size in size_same: + print('FixedPostNum:', size) + conn = bp.connect.FixedPostNum(num=10, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + mat = conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'num=10', + time_used) + + for size in size_diff: + print('FixedPostNum:', size) + conn = bp.connect.FixedPreNum(num=10, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'pre_num=10', + time_used) + + def test_prob_dist(self): + print() + for size in size_same: + print('ProbDist:', size) + conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, 
seed=1234, include_self=True) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'prob=0.5', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', + time_used) + + def test_small_world(self): + print() + for size in size_same: + print('SmallWorld:', size) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + def test_scale_free_ba(self): + print() + for size in size_same: + print('ScaleFreeBA:', size) + conn = bp.connect.ScaleFreeBA(m=2) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm=2', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm=2', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm=2', + time_used) + + def test_scale_free_ba_dual(self): + print() + for size in size_same: + print('ScaleFreeBADual:', size) + conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm1=2/m2=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm1=2/m2=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm1=2/m2=3/p=0.4', + time_used) + + def test_power_law(self): + print() + for size in size_same: + print('PowerLaw:', size) + conn = bp.connect.PowerLaw(m=3, p=0.4) + conn(pre_size=size, post_size=size) + + start = time.time() + 
conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm=3/p=0.4', + time_used) + + def test_one2one(self): + print() + for size in size_same: + print('One2One:', size) + conn = bp.connect.One2One() + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + '', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + '', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + '', + time_used) + + def test_all2all(self): + print() + for size in size_same: + print('All2All:', size) + conn = bp.connect.All2All() + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('All2All', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + '', + time_used) + + # start = time.time() + # conn.require(bp.connect.COO) + # time_used = get_ms(time.time() - start) + # df.loc[len(df)] = ['All2All', + # 'TwoEndConnector', + # f'{size}x{size}', + # 'build_coo', + # '', + # time_used] + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('All2All', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + '', + time_used) class TestSave(unittest.TestCase): - def test_save(self): - try: - df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', - index=False) - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + def test_save(self): + try: + df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', + index=False) + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index bdd3102a3..9ebad3e94 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,2 +1,2 @@ -from ._csr_matvec import * +from .csr_matvec import * diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py similarity index 99% rename from brainpy/_src/math/event/_csr_matvec.py rename to brainpy/_src/math/event/csr_matvec.py index 6b7f7da02..2d801ee7c 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -20,8 +20,8 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi -from brainpy._src.math.sparse._utils import csr_to_coo +from 
brainpy._src.math.sparse.csr_mv import raw_csrmv_taichi as normal_csrmv_taichi +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy.errors import PackageMissingError __all__ = [ diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 67e09d0a4..6c0a2ed47 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -9,13 +9,11 @@ import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) - seed = 1234 @@ -26,7 +24,6 @@ def func(*args, **kwargs): return func -taichi_csr_matvec = bm.event.csrmv class Test_event_csr_matvec_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -37,22 +34,22 @@ def __init__(self, *args, platform='cpu', **kwargs): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo(self, transpose, shape, homo_data): print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') events = rng.random(shape[0] if transpose else shape[1]) < 0.1 heter_data = bm.ones(indices.shape) * homo_data dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (events @ dense) if transpose else (dense @ events) - r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -60,23 +57,22 @@ def test_homo(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo_vmap(self, shape, transpose, homo_data): print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, + f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax([homo_data] * 10) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) @@ -84,14 +80,14 @@ def test_homo_vmap(self, shape, transpose, homo_data): # vmap 'events' f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr, + f4 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 
self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax([homo_data] * 10) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 @@ -102,16 +98,15 @@ def test_homo_vmap(self, shape, transpose, homo_data): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo_grad(self, shape, transpose, homo_data): print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -119,31 +114,26 @@ def test_homo_grad(self, shape, transpose, homo_data): dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) # grad 'data' - r1 = jax.grad(sum_op(bm.sparse.csrmv))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r1 = jax.grad(sum_op(bm.sparse.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(bm.event.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' - r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape, + transpose=transpose) + r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape, + transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) bm.clear_buffer_memory() @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000), ] + shape=[(100, 200), (10, 1000), ] ) def test_heter(self, shape, transpose): print(f'test_heter: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -151,9 +141,9 @@ def test_heter(self, shape, transpose): heter_data = bm.as_jax(rng.random(indices.shape)) r1 = bm.sparse.csrmv(heter_data, indices, indptr, events, + shape=shape, transpose=transpose) + r2 = bm.event.csrmv(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = taichi_csr_matvec(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -161,24 +151,21 @@ def test_heter(self, shape, transpose): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] + 
shape=[(100, 200), (10, 1000)] ) def test_heter_vmap(self, shape, transpose): print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, + f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, + f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) @@ -187,16 +174,16 @@ def test_heter_vmap(self, shape, transpose): data = bm.as_jax(rng.random(indices.shape)) f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr, + f4 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, + shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), @@ -206,15 +193,12 @@ def test_heter_vmap(self, shape, transpose): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] + shape=[(100, 200), (10, 1000)] ) def test_heter_grad(self, shape, transpose): print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -226,20 +210,20 @@ def test_heter_grad(self, shape, transpose): data = bm.as_jax(rng.random(indices.shape)) r1 = jax.grad(sum_op(bm.sparse.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( + r2 = jax.grad(sum_op(bm.event.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( + r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( + r6 = 
jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index 6f7cddf6a..bb6a3c1f4 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,2 +1,2 @@ -from ._matvec import * -from ._event_matvec import * +from .matvec import * +from .event_matvec import * diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py similarity index 98% rename from brainpy/_src/math/jitconn/_event_matvec.py rename to brainpy/_src/math/jitconn/event_matvec.py index 976b72b96..3e4048c88 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -8,17 +8,17 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.jitconn._matvec import (mv_prob_homo, - mv_prob_uniform, - mv_prob_normal, - _general_checking, - raw_mv_prob_homo, - raw_mv_prob_uniform, - raw_mv_prob_normal, - _mv_prob_homo_transpose, - _mv_prob_uniform_transpose, - _mv_prob_normal_transpose, - _reverse) +from brainpy._src.math.jitconn.matvec import (mv_prob_homo, + mv_prob_uniform, + mv_prob_normal, + _general_checking, + raw_mv_prob_homo, + raw_mv_prob_uniform, + raw_mv_prob_normal, + _mv_prob_homo_transpose, + _mv_prob_uniform_transpose, + _mv_prob_normal_transpose, + _reverse) from brainpy._src.math.ndarray import _get_dtype from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/matvec.py similarity index 100% rename from brainpy/_src/math/jitconn/_matvec.py rename to brainpy/_src/math/jitconn/matvec.py diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index d8e086540..6fb8d02ec 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- -from functools import partial import jax import jax.numpy as jnp -from absl.testing import parameterized import pytest +from absl.testing import parameterized import brainpy.math as bm from brainpy._src.dependency_check import import_taichi @@ -12,515 +11,419 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) - shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] - -taichi_mv_prob_homo = bm.jitconn.event_mv_prob_homo -taichi_mv_prob_uniform = bm.jitconn.event_mv_prob_uniform -taichi_mv_prob_normal = bm.jitconn.event_mv_prob_normal +shapes = [(100, 200), (1000, 10)] class Test_event_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - homo_data=[-1., ], - bool_event=[True, False], - seed=[1234], - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): - print(f'_test_homo: ' - f'shape = 
{shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') - # indices = bm.as_jax(indices) - # indptr = bm.as_jax(indptr) - # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, - # shape=shape, transpose=transpose) - # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - bool_event=[True, False], - seed=[1234], + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + homo_data=[-1.], + bool_event=[True, False], + seed=[1234], + ) + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): + print(f'_test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') + # indices = bm.as_jax(indices) + # indptr = bm.as_jax(indptr) + # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, + # shape=shape, transpose=transpose) + # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + bool_event=[True, False], + seed=[1234], + ) + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): + print(f'_test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'bool_event = {bool_event}, ' + 
f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: bm.jitconn.event_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + transpose=transpose, outdim_parallel=outdim_parallel + )[0] ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): - print(f'_test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - transpose=transpose, outdim_parallel=outdim_parallel - )[0] - ) - r1 = f1(events, weights) - r1 = jax.block_until_ready(r1) - r2 = f1(events, weights) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.5] + r1 = f1(events, weights) + r1 = jax.block_until_ready(r1) + r2 = f1(events, weights) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.5 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: bm.jitconn.event_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), + argnums=0 ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.5 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), - argnums=0 - ) - r1 = f1(events, 1.) 
- r1 = jax.block_until_ready(r1) - - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - - r3 = f1(events, 3.) - r3 = jax.block_until_ready(r3) - - self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) - self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - bool_event=bool_event, - seed=1234, - x64=x64 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.4] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - for bool_event in [True, False] + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + + r3 = f1(events, 3.) + r3 = jax.block_until_ready(r3) + + self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_low=[-1.], + w_high=[1.], + bool_event=[True, False] + ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + bool_event=[True, False], + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap( + lambda e: bm.jitconn.event_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) ) - def test_uniform(self, shape, transpose, outdim_parallel, 
prob, w_low, w_high, - bool_event=True, seed=1234, x64=False): - print(f'_test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_uniform_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'bool_event={bool_event}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=1234, x64=False): - print(f'_test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap( - lambda e: taichi_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - ) - - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - testcase_name=f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_high: taichi_mv_prob_uniform( - e, - w_low=0., - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - - 
r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - # print(r1) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu={w_mu}, ' - f'w_sigma={w_sigma}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, ] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - for bool_event in [True, False] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, - bool_event=True, seed=1234, x64=False): - print(f'_test_normal: shape = {shape}, ' - f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' - f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] + + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_high: bm.jitconn.event_mv_prob_uniform( + e, + w_low=0., + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() ) - def test_normal_vmap(self, shape, 
transpose, outdim_parallel, prob, - bool_event=True, seed=1234, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + # print(r1) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1, ], + w_mu=[0.], + w_sigma=[0.1], + bool_event=[True, False], + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal: shape = {shape}, ' + f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' + f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose = [True, False], + x64 = [True, False], + outdim_parallel = [True, False], + shape = shapes, + prob = [0.1], + bool_event = [True, False], + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + 
conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose = [True, False], + x64 = [True, False], + outdim_parallel = [True, False], + shape = shapes, + prob = [0.1] + ) + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.jit( + jax.grad( + lambda e, w_sigma: bm.jitconn.event_mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() + ) ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.jit( - jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) 
+ r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 8a0ae444d..67c18124f 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- -from functools import partial import jax import jax.numpy as jnp -from absl.testing import parameterized import pytest +from absl.testing import parameterized import brainpy.math as bm from brainpy._src.dependency_check import import_taichi @@ -12,55 +11,38 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) -shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] - -taichi_mv_prob_homo = bm.jitconn.mv_prob_homo -taichi_mv_prob_uniform = bm.jitconn.mv_prob_uniform -taichi_mv_prob_normal = bm.jitconn.mv_prob_normal +shapes = [(100, 200), (1000, 10)] class Test_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - homo_data=homo_data, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for homo_data in [-1., 1.] - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): - print(f'test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_homo(vector, + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + homo_data=[1.] 
+ ) + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): + print(f'test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, @@ -68,152 +50,118 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=123 outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_homo(vector, + r2 = bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( - event, data, - conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - )[0] - ) - r1 = f1(events, weights) - r2 = f1(events, weights) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: bm.jitconn.mv_prob_homo( + event, data, + conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose + )[0] ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - 
f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( - event, data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - - self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - x64=x64, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] + r1 = f1(events, weights) + r2 = f1(events, weights) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: bm.jitconn.mv_prob_homo( + event, data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum(), + argnums=0 ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): - print(f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64 = {x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_uniform(events, + r1 = f1(events, 1.) + r2 = f1(events, 2.) 
+ + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_low=[-0.1], + w_high=[1.0], + ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): + print(f'test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64 = {x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -222,7 +170,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_uniform(events, + r2 = bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -230,45 +178,35 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e, + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform(e, w_low=0., w_high=1., conn_prob=prob, @@ -277,107 +215,81 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, outdim_parallel=outdim_parallel, transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - 
f'prob={prob}, ' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - f1 = jax.grad( - lambda e, w_low, w_high: taichi_mv_prob_uniform( - e, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() - ) - - r1 = f1(events, 0., 1.) - r2 = f1(events, 0., 2.) - - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=(f'test_normal, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma},' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - seed=1234 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] + r1 = f1(events) + r2 = f1(events) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + f1 = jax.grad( + lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform( + e, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): - print(f'_test_normal: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_normal(events, + + r1 = f1(events, 0., 1.) + r2 = f1(events, 0., 2.) 
+ + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_mu=[0.], + w_sigma=[0.2] + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): + print(f'_test_normal: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_mu = {w_mu}, ' + f'w_sigma = {w_sigma}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -386,7 +298,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_normal(events, + r2 = bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -394,46 +306,36 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_normal_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal(e, w_mu=0., w_sigma=1., conn_prob=prob, @@ -441,66 +343,54 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - print(r1 - r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - 
x64=x64, - testcase_name=f'test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + r1 = f1(events) + r2 = f1(events) + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + print(r1 - r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_sigma: bm.jitconn.mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + r1 = f1(events, 1.) + r2 = f1(events, 2.) 
+ self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index d53533247..14256cbce 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,8 +1,8 @@ # from ._coo_mv import * # from ._bsr_mv import * -from ._csr_mv import * -from ._utils import * -from ._bsr_mm import * -from ._jax_prim import * +from .csr_mv import * +from .utils import * +from .bsr_mm import * +from .jax_prim import * diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/bsr_mm.py similarity index 100% rename from brainpy/_src/math/sparse/_bsr_mm.py rename to brainpy/_src/math/sparse/bsr_mm.py diff --git a/brainpy/_src/math/sparse/_bsr_mv.py b/brainpy/_src/math/sparse/bsr_mv.py similarity index 99% rename from brainpy/_src/math/sparse/_bsr_mv.py rename to brainpy/_src/math/sparse/bsr_mv.py index a35895bc1..7dc0b683d 100644 --- a/brainpy/_src/math/sparse/_bsr_mv.py +++ b/brainpy/_src/math/sparse/bsr_mv.py @@ -11,7 +11,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy._src.dependency_check import import_brainpylib_gpu_ops from brainpy.errors import GPUOperatorNotFound diff --git a/brainpy/_src/math/sparse/_coo_mv.py b/brainpy/_src/math/sparse/coo_mv.py similarity index 100% rename from brainpy/_src/math/sparse/_coo_mv.py rename to brainpy/_src/math/sparse/coo_mv.py diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/csr_mv.py similarity index 99% rename from brainpy/_src/math/sparse/_csr_mv.py rename to brainpy/_src/math/sparse/csr_mv.py index 42969f435..31eec80d7 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/csr_mv.py @@ -13,7 +13,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (register_general_batching, XLACustomOp) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy.errors import PackageMissingError ti = import_taichi(error_if_not_found=False) diff --git a/brainpy/_src/math/sparse/_jax_prim.py b/brainpy/_src/math/sparse/jax_prim.py similarity index 100% rename from brainpy/_src/math/sparse/_jax_prim.py rename to brainpy/_src/math/sparse/jax_prim.py diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index ec448e658..12dc6a59e 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -3,12 +3,11 @@ from functools import partial import jax +import pytest from absl.testing import parameterized -import pytest import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: @@ -25,7 +24,6 @@ def func(*args, **kwargs): return func - def compare_with_nan_tolerance(a, b, tol=1e-8): """ Compare two arrays with tolerance for NaN values. 
@@ -58,6 +56,7 @@ def compare_with_nan_tolerance(a, b, tol=1e-8): taichi_csr_matvec = bm.sparse.csrmv + class Test_csrmv_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): super(Test_csrmv_taichi, self).__init__(*args, **kwargs) @@ -67,8 +66,8 @@ def __init__(self, *args, platform='cpu', **kwargs): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] + shape=[(200, 200), (10, 1000)], + homo_data=[1.] ) def test_homo(self, transpose, shape, homo_data): print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') @@ -94,8 +93,8 @@ def test_homo(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (100, 1000), (2, 2000)], - v=[-1., 0., 1.] + shape=[(200, 200), (100, 1000)], + v=[1.] ) def test_homo_vmap(self, transpose, shape, v): print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}') @@ -123,8 +122,8 @@ def test_homo_vmap(self, transpose, shape, v): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] + shape=[(200, 200), (10, 1000)], + homo_data=[1.] ) def test_homo_grad(self, transpose, shape, homo_data): print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') @@ -177,7 +176,7 @@ def test_homo_grad(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (2, 2000)], + shape=[(200, 200), (2, 2000)], ) def test_heter(self, transpose, shape): print(f'test_homo: transpose = {transpose} shape = {shape}') @@ -204,7 +203,7 @@ def test_heter(self, transpose, shape): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + shape=[(200, 200), (2, 2000)] ) def test_heter_vmap(self, transpose, shape): rng = bm.random.RandomState(seed=seed) @@ -230,7 +229,7 @@ def test_heter_vmap(self, transpose, shape): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + shape=[(200, 200), (2, 2000)] ) def test_heter_grad(self, transpose, shape): rng = bm.random.RandomState(seed=seed) @@ -249,8 +248,8 @@ def test_heter_grad(self, transpose, shape): dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), argnums=0) csr_f1 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, - shape=shape, - transpose=transpose).sum(), + shape=shape, + transpose=transpose).sum(), argnums=0) r1 = csr_f1(heter_data) r2 = dense_f1(dense_data) @@ -263,9 +262,9 @@ def test_heter_grad(self, transpose, shape): dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), argnums=0) csr_f2 = jax.grad(lambda v: taichi_csr_matvec(heter_data, indices, indptr, v, - shape=shape, - transpose=transpose).sum(), - argnums=0) + shape=shape, + transpose=transpose).sum(), + argnums=0) r3 = dense_f2(vector) r4 = csr_f2(vector) self.assertTrue(bm.allclose(r3, r4)) diff --git a/brainpy/_src/math/sparse/_utils.py b/brainpy/_src/math/sparse/utils.py similarity index 100% rename from brainpy/_src/math/sparse/_utils.py rename to brainpy/_src/math/sparse/utils.py
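
For reference, a minimal self-contained sketch (not part of the patch) of the conventions this change adopts: calling the kernels through their public `bm.jitconn` entry points instead of module-level aliases, and driving the test matrix with `parameterized.product`. The class and test names below are hypothetical, and it assumes a brainpy build that already includes these renames plus an installed taichi backend.

# -*- coding: utf-8 -*-
# Illustrative sketch only: the class/test names are hypothetical and not part of the patch.

import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized

import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

# Skip the whole module when the taichi backend is unavailable, as the updated tests do.
if import_taichi(error_if_not_found=False) is None:
  pytest.skip('no taichi', allow_module_level=True)


class Example_event_mv_prob_homo(parameterized.TestCase):

  @parameterized.product(
    transpose=[True, False],
    shape=[(100, 200)],
    prob=[0.1],
  )
  def test_same_seed_gives_same_result(self, shape, transpose, prob):
    rng = bm.random.RandomState()
    events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1

    # Calling the public bm.jitconn entry point twice with the same seed
    # should reproduce the same output.
    r1 = bm.jitconn.event_mv_prob_homo(events, 1., conn_prob=prob, shape=shape,
                                       seed=1234, transpose=transpose)
    r2 = bm.jitconn.event_mv_prob_homo(events, 1., conn_prob=prob, shape=shape,
                                       seed=1234, transpose=transpose)
    self.assertTrue(jnp.allclose(jax.block_until_ready(r1),
                                 jax.block_until_ready(r2), atol=1e-6))
    bm.clear_buffer_memory()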