From 446c104bdea94161999148c57331bc4512d6b89e Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 1 Mar 2024 13:42:47 +0800 Subject: [PATCH 01/11] Update requirements-dev-raw.txt --- requirements-dev-raw.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev-raw.txt b/requirements-dev-raw.txt index 99361efa..1234b836 100644 --- a/requirements-dev-raw.txt +++ b/requirements-dev-raw.txt @@ -1,4 +1,5 @@ numpy +brainpylib jax jaxlib matplotlib From e17258f01fe824c32e455d8b77f2b0c50e6aa065 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 1 Mar 2024 15:11:47 +0800 Subject: [PATCH 02/11] update tests --- .github/workflows/CI.yml | 197 --- brainpy/_src/connect/tests/test_all_time.py | 1272 +++++++++-------- brainpy/_src/math/event/__init__.py | 2 +- .../event/{_csr_matvec.py => csr_matvec.py} | 4 +- .../_src/math/event/tests/test_event_csrmv.py | 100 +- brainpy/_src/math/jitconn/__init__.py | 4 +- .../{_event_matvec.py => event_matvec.py} | 22 +- .../math/jitconn/{_matvec.py => matvec.py} | 0 .../math/jitconn/tests/test_event_matvec.py | 913 ++++++------ .../_src/math/jitconn/tests/test_matvec.py | 748 +++++----- brainpy/_src/math/sparse/__init__.py | 8 +- .../math/sparse/{_bsr_mm.py => bsr_mm.py} | 0 .../math/sparse/{_bsr_mv.py => bsr_mv.py} | 2 +- .../math/sparse/{_coo_mv.py => coo_mv.py} | 0 .../math/sparse/{_csr_mv.py => csr_mv.py} | 2 +- .../math/sparse/{_jax_prim.py => jax_prim.py} | 0 brainpy/_src/math/sparse/tests/test_csrmv.py | 33 +- .../_src/math/sparse/{_utils.py => utils.py} | 0 18 files changed, 1444 insertions(+), 1863 deletions(-) rename brainpy/_src/math/event/{_csr_matvec.py => csr_matvec.py} (99%) rename brainpy/_src/math/jitconn/{_event_matvec.py => event_matvec.py} (98%) rename brainpy/_src/math/jitconn/{_matvec.py => matvec.py} (100%) rename brainpy/_src/math/sparse/{_bsr_mm.py => bsr_mm.py} (100%) rename brainpy/_src/math/sparse/{_bsr_mv.py => bsr_mv.py} (99%) rename brainpy/_src/math/sparse/{_coo_mv.py => coo_mv.py} (100%) rename brainpy/_src/math/sparse/{_csr_mv.py => csr_mv.py} (99%) rename brainpy/_src/math/sparse/{_jax_prim.py => jax_prim.py} (100%) rename brainpy/_src/math/sparse/{_utils.py => utils.py} (100%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 95bd8eaf..66af94fc 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -50,71 +50,6 @@ jobs: cd brainpy pytest _src/ - test_linux_with_taichi_numba: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [ "3.9", "3.10", "3.11"] - - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest taichi numba - if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi - pip uninstall brainpy -y - python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - cd brainpy - pytest _src/ - - -# test_linux_py37: -# runs-on: ubuntu-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.7"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi -# pip install jax==0.3.25 -# pip install jaxlib==0.3.25 -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ -# test_macos: runs-on: macos-latest @@ -147,135 +82,3 @@ jobs: cd brainpy pytest _src/ - test_macos_with_taichi_numba: - runs-on: macos-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.9", "3.10", "3.11"] - - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest taichi numba - if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi - pip uninstall brainpy -y - python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - cd brainpy - pytest _src/ - -# test_macos_py37: -# runs-on: macos-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: [ "3.7" ] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi -# pip install jax==0.3.25 -# pip install jaxlib==0.3.25 -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ -# - - -# test_windows: -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.9", "3.10", "3.11"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# python -m pip install -r requirements-dev.txt -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd brainpy -# pytest _src/ - - -# test_windows_py37: -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.7"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# python -m pip install numpy>=1.21.0 -# python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver -# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz -# python -m pip install -r requirements-dev.txt -# python -m pip install tqdm brainpylib -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py index b634d6db..07422f65 100644 --- a/brainpy/_src/connect/tests/test_all_time.py +++ b/brainpy/_src/connect/tests/test_all_time.py @@ -1,18 +1,19 @@ import time from datetime import datetime -import brainpy as bp -import unittest import pytest +import brainpy as bp + +pytest.skip('skip.', allow_module_level=True) + try: - import pandas as pd + import pandas as pd - df = pd.DataFrame( - columns=['connector name', 'superclass', 'connect matrix size', 'build function', 'other parameter', - 'time(ms)']) + df = pd.DataFrame(columns=['connector name', 'superclass', 'connect matrix size', + 'build function', 'other parameter', 'time(ms)']) except (ImportError, ModuleNotFoundError): - print('No pandas installed, skip test.') + print('No pandas installed, skip test.') # size_same = [100, 500, 2500, 12500, 25000, 37500, 50000] # size_same = [100, 500, 2500, 12500] @@ -21,644 +22,645 @@ size_same = [100, 500, 2500] size_diff = [(10, 100), (100, 1000)] + def get_ms(value): - return round(value * 1000, 4) + return round(value * 1000, 4) def insert_row(connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used): - try: - df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used] - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + try: + df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used] + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') class OneEndConnector(unittest.TestCase): - def test_gaussian_prob(self): - print() - for size in size_same: - print('GaussianProb:', size) - conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GaussianProb', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'sigma=1/include_self=False', - time_used) - - # start = time.time() - # conn.require(bp.connect.COO) - # time_used = get_ms(time.time() - start) - # df.loc[len(df)] = ['GaussianProb', - # 'OneEndConnector', - # f'{size}x{size}', - # 'build_coo', - # 'sigma=1/include_self=False', - # time_used] - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GaussianProb', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'sigma=1/include_self=False', - time_used) - - def test_grid_four(self): - print() - for size in size_same: - print('GridFour:', size) - conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False', - time_used) - - def test_grid_eight(self): - print() - for size in size_same: - print('GridEight:', size) - conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False', - time_used) - - def test_grid_n(self): - print() - for size in size_same: - print('GridN:', size) - conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False/N=2', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False/N=2', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False/N=2', - time_used) + def test_gaussian_prob(self): + print() + for size in size_same: + print('GaussianProb:', size) + conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GaussianProb', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'sigma=1/include_self=False', + time_used) + + # start = time.time() + # conn.require(bp.connect.COO) + # time_used = get_ms(time.time() - start) + # df.loc[len(df)] = ['GaussianProb', + # 'OneEndConnector', + # f'{size}x{size}', + # 'build_coo', + # 'sigma=1/include_self=False', + # time_used] + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GaussianProb', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'sigma=1/include_self=False', + time_used) + + def test_grid_four(self): + print() + for size in size_same: + print('GridFour:', size) + conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False', + time_used) + + def test_grid_eight(self): + print() + for size in size_same: + print('GridEight:', size) + conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False', + time_used) + + def test_grid_n(self): + print() + for size in size_same: + print('GridN:', size) + conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False/N=2', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False/N=2', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False/N=2', + time_used) class TwoEndConnector(unittest.TestCase): - def test_fixed_prob(self): - print() - for size in size_same: - print('FixedProb:', size) - conn = bp.connect.FixedProb(prob=0.1, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'prob=0.1', - time_used) - - for size in size_diff: - print('FixedProb:', size) - conn = bp.connect.FixedProb(prob=0.1, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'prob=0.1', - time_used) - - def test_fixed_pre_num(self): - print() - for size in size_same: - print('FixedPreNum:', size) - conn = bp.connect.FixedPreNum(num=0.4, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'pre_num=10', - time_used) - - for size in size_diff: - print('FixedPreNum:', size) - conn = bp.connect.FixedPreNum(num=0.4, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'pre_num=10', - time_used) - - def test_fixed_post_num(self): - print() - for size in size_same: - print('FixedPostNum:', size) - conn = bp.connect.FixedPostNum(num=10, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - mat = conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'num=10', - time_used) - - for size in size_diff: - print('FixedPostNum:', size) - conn = bp.connect.FixedPreNum(num=10, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'pre_num=10', - time_used) - - def test_prob_dist(self): - print() - for size in size_same: - print('ProbDist:', size) - conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'prob=0.5', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', - time_used) - - def test_small_world(self): - print() - for size in size_same: - print('SmallWorld:', size) - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - def test_scale_free_ba(self): - print() - for size in size_same: - print('ScaleFreeBA:', size) - conn = bp.connect.ScaleFreeBA(m=2) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm=2', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm=2', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm=2', - time_used) - - def test_scale_free_ba_dual(self): - print() - for size in size_same: - print('ScaleFreeBADual:', size) - conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm1=2/m2=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm1=2/m2=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm1=2/m2=3/p=0.4', - time_used) - - def test_power_law(self): - print() - for size in size_same: - print('PowerLaw:', size) - conn = bp.connect.PowerLaw(m=3, p=0.4) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm=3/p=0.4', - time_used) - - def test_one2one(self): - print() - for size in size_same: - print('One2One:', size) - conn = bp.connect.One2One() - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - '', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - '', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - '', - time_used) - - def test_all2all(self): - print() - for size in size_same: - print('All2All:', size) - conn = bp.connect.All2All() - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('All2All', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - '', - time_used) - - # start = time.time() - # conn.require(bp.connect.COO) - # time_used = get_ms(time.time() - start) - # df.loc[len(df)] = ['All2All', - # 'TwoEndConnector', - # f'{size}x{size}', - # 'build_coo', - # '', - # time_used] - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('All2All', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - '', - time_used) + def test_fixed_prob(self): + print() + for size in size_same: + print('FixedProb:', size) + conn = bp.connect.FixedProb(prob=0.1, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'prob=0.1', + time_used) + + for size in size_diff: + print('FixedProb:', size) + conn = bp.connect.FixedProb(prob=0.1, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'prob=0.1', + time_used) + + def test_fixed_pre_num(self): + print() + for size in size_same: + print('FixedPreNum:', size) + conn = bp.connect.FixedPreNum(num=0.4, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'pre_num=10', + time_used) + + for size in size_diff: + print('FixedPreNum:', size) + conn = bp.connect.FixedPreNum(num=0.4, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'pre_num=10', + time_used) + + def test_fixed_post_num(self): + print() + for size in size_same: + print('FixedPostNum:', size) + conn = bp.connect.FixedPostNum(num=10, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + mat = conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'num=10', + time_used) + + for size in size_diff: + print('FixedPostNum:', size) + conn = bp.connect.FixedPreNum(num=10, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'pre_num=10', + time_used) + + def test_prob_dist(self): + print() + for size in size_same: + print('ProbDist:', size) + conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'prob=0.5', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', + time_used) + + def test_small_world(self): + print() + for size in size_same: + print('SmallWorld:', size) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + def test_scale_free_ba(self): + print() + for size in size_same: + print('ScaleFreeBA:', size) + conn = bp.connect.ScaleFreeBA(m=2) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm=2', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm=2', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm=2', + time_used) + + def test_scale_free_ba_dual(self): + print() + for size in size_same: + print('ScaleFreeBADual:', size) + conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm1=2/m2=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm1=2/m2=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm1=2/m2=3/p=0.4', + time_used) + + def test_power_law(self): + print() + for size in size_same: + print('PowerLaw:', size) + conn = bp.connect.PowerLaw(m=3, p=0.4) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm=3/p=0.4', + time_used) + + def test_one2one(self): + print() + for size in size_same: + print('One2One:', size) + conn = bp.connect.One2One() + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + '', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + '', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + '', + time_used) + + def test_all2all(self): + print() + for size in size_same: + print('All2All:', size) + conn = bp.connect.All2All() + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('All2All', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + '', + time_used) + + # start = time.time() + # conn.require(bp.connect.COO) + # time_used = get_ms(time.time() - start) + # df.loc[len(df)] = ['All2All', + # 'TwoEndConnector', + # f'{size}x{size}', + # 'build_coo', + # '', + # time_used] + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('All2All', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + '', + time_used) class TestSave(unittest.TestCase): - def test_save(self): - try: - df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', - index=False) - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + def test_save(self): + try: + df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', + index=False) + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index bdd3102a..9ebad3e9 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,2 +1,2 @@ -from ._csr_matvec import * +from .csr_matvec import * diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py similarity index 99% rename from brainpy/_src/math/event/_csr_matvec.py rename to brainpy/_src/math/event/csr_matvec.py index 6b7f7da0..2d801ee7 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -20,8 +20,8 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.csr_mv import raw_csrmv_taichi as normal_csrmv_taichi +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy.errors import PackageMissingError __all__ = [ diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 67e09d0a..6c0a2ed4 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -9,13 +9,11 @@ import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) - seed = 1234 @@ -26,7 +24,6 @@ def func(*args, **kwargs): return func -taichi_csr_matvec = bm.event.csrmv class Test_event_csr_matvec_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -37,22 +34,22 @@ def __init__(self, *args, platform='cpu', **kwargs): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo(self, transpose, shape, homo_data): print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') events = rng.random(shape[0] if transpose else shape[1]) < 0.1 heter_data = bm.ones(indices.shape) * homo_data dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (events @ dense) if transpose else (dense @ events) - r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -60,23 +57,22 @@ def test_homo(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo_vmap(self, shape, transpose, homo_data): print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, + f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax([homo_data] * 10) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) @@ -84,14 +80,14 @@ def test_homo_vmap(self, shape, transpose, homo_data): # vmap 'events' f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr, + f4 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax([homo_data] * 10) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 @@ -102,16 +98,15 @@ def test_homo_vmap(self, shape, transpose, homo_data): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo_grad(self, shape, transpose, homo_data): print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -119,31 +114,26 @@ def test_homo_grad(self, shape, transpose, homo_data): dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) # grad 'data' - r1 = jax.grad(sum_op(bm.sparse.csrmv))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r1 = jax.grad(sum_op(bm.sparse.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(bm.event.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' - r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape, + transpose=transpose) + r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape, + transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) bm.clear_buffer_memory() @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000), ] + shape=[(100, 200), (10, 1000), ] ) def test_heter(self, shape, transpose): print(f'test_heter: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -151,9 +141,9 @@ def test_heter(self, shape, transpose): heter_data = bm.as_jax(rng.random(indices.shape)) r1 = bm.sparse.csrmv(heter_data, indices, indptr, events, + shape=shape, transpose=transpose) + r2 = bm.event.csrmv(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = taichi_csr_matvec(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -161,24 +151,21 @@ def test_heter(self, shape, transpose): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] + shape=[(100, 200), (10, 1000)] ) def test_heter_vmap(self, shape, transpose): print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, + f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, + f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) @@ -187,16 +174,16 @@ def test_heter_vmap(self, shape, transpose): data = bm.as_jax(rng.random(indices.shape)) f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr, + f4 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, + shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), @@ -206,15 +193,12 @@ def test_heter_vmap(self, shape, transpose): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] + shape=[(100, 200), (10, 1000)] ) def test_heter_grad(self, shape, transpose): print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -226,20 +210,20 @@ def test_heter_grad(self, shape, transpose): data = bm.as_jax(rng.random(indices.shape)) r1 = jax.grad(sum_op(bm.sparse.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( + r2 = jax.grad(sum_op(bm.event.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( + r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( + r6 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index 6f7cddf6..bb6a3c1f 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,2 +1,2 @@ -from ._matvec import * -from ._event_matvec import * +from .matvec import * +from .event_matvec import * diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py similarity index 98% rename from brainpy/_src/math/jitconn/_event_matvec.py rename to brainpy/_src/math/jitconn/event_matvec.py index 976b72b9..3e4048c8 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -8,17 +8,17 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.jitconn._matvec import (mv_prob_homo, - mv_prob_uniform, - mv_prob_normal, - _general_checking, - raw_mv_prob_homo, - raw_mv_prob_uniform, - raw_mv_prob_normal, - _mv_prob_homo_transpose, - _mv_prob_uniform_transpose, - _mv_prob_normal_transpose, - _reverse) +from brainpy._src.math.jitconn.matvec import (mv_prob_homo, + mv_prob_uniform, + mv_prob_normal, + _general_checking, + raw_mv_prob_homo, + raw_mv_prob_uniform, + raw_mv_prob_normal, + _mv_prob_homo_transpose, + _mv_prob_uniform_transpose, + _mv_prob_normal_transpose, + _reverse) from brainpy._src.math.ndarray import _get_dtype from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/matvec.py similarity index 100% rename from brainpy/_src/math/jitconn/_matvec.py rename to brainpy/_src/math/jitconn/matvec.py diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index d8e08654..6fb8d02e 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- -from functools import partial import jax import jax.numpy as jnp -from absl.testing import parameterized import pytest +from absl.testing import parameterized import brainpy.math as bm from brainpy._src.dependency_check import import_taichi @@ -12,515 +11,419 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) - shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] - -taichi_mv_prob_homo = bm.jitconn.event_mv_prob_homo -taichi_mv_prob_uniform = bm.jitconn.event_mv_prob_uniform -taichi_mv_prob_normal = bm.jitconn.event_mv_prob_normal +shapes = [(100, 200), (1000, 10)] class Test_event_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - homo_data=[-1., ], - bool_event=[True, False], - seed=[1234], - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): - print(f'_test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') - # indices = bm.as_jax(indices) - # indptr = bm.as_jax(indptr) - # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, - # shape=shape, transpose=transpose) - # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - bool_event=[True, False], - seed=[1234], + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + homo_data=[-1.], + bool_event=[True, False], + seed=[1234], + ) + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): + print(f'_test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') + # indices = bm.as_jax(indices) + # indptr = bm.as_jax(indptr) + # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, + # shape=shape, transpose=transpose) + # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + bool_event=[True, False], + seed=[1234], + ) + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): + print(f'_test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: bm.jitconn.event_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + transpose=transpose, outdim_parallel=outdim_parallel + )[0] ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): - print(f'_test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - transpose=transpose, outdim_parallel=outdim_parallel - )[0] - ) - r1 = f1(events, weights) - r1 = jax.block_until_ready(r1) - r2 = f1(events, weights) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.5] + r1 = f1(events, weights) + r1 = jax.block_until_ready(r1) + r2 = f1(events, weights) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.5 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: bm.jitconn.event_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), + argnums=0 ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.5 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - - r3 = f1(events, 3.) - r3 = jax.block_until_ready(r3) - - self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) - self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - bool_event=bool_event, - seed=1234, - x64=x64 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.4] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - for bool_event in [True, False] + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + + r3 = f1(events, 3.) + r3 = jax.block_until_ready(r3) + + self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_low=[-1.], + w_high=[1.], + bool_event=[True, False] + ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + bool_event=[True, False], + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap( + lambda e: bm.jitconn.event_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, - bool_event=True, seed=1234, x64=False): - print(f'_test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_uniform_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'bool_event={bool_event}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=1234, x64=False): - print(f'_test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap( - lambda e: taichi_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - ) - - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - testcase_name=f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_high: taichi_mv_prob_uniform( - e, - w_low=0., - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - # print(r1) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu={w_mu}, ' - f'w_sigma={w_sigma}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, ] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - for bool_event in [True, False] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, - bool_event=True, seed=1234, x64=False): - print(f'_test_normal: shape = {shape}, ' - f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' - f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] + + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_high: bm.jitconn.event_mv_prob_uniform( + e, + w_low=0., + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=1234, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + # print(r1) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1, ], + w_mu=[0.], + w_sigma=[0.1], + bool_event=[True, False], + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal: shape = {shape}, ' + f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' + f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose = [True, False], + x64 = [True, False], + outdim_parallel = [True, False], + shape = shapes, + prob = [0.1], + bool_event = [True, False], + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose = [True, False], + x64 = [True, False], + outdim_parallel = [True, False], + shape = shapes, + prob = [0.1] + ) + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.jit( + jax.grad( + lambda e, w_sigma: bm.jitconn.event_mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() + ) ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.jit( - jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 8a0ae444..67c18124 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- -from functools import partial import jax import jax.numpy as jnp -from absl.testing import parameterized import pytest +from absl.testing import parameterized import brainpy.math as bm from brainpy._src.dependency_check import import_taichi @@ -12,55 +11,38 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) -shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] - -taichi_mv_prob_homo = bm.jitconn.mv_prob_homo -taichi_mv_prob_uniform = bm.jitconn.mv_prob_uniform -taichi_mv_prob_normal = bm.jitconn.mv_prob_normal +shapes = [(100, 200), (1000, 10)] class Test_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - homo_data=homo_data, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for homo_data in [-1., 1.] - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): - print(f'test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_homo(vector, + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + homo_data=[1.] + ) + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): + print(f'test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, @@ -68,152 +50,118 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=123 outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_homo(vector, + r2 = bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( - event, data, - conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - )[0] - ) - r1 = f1(events, weights) - r2 = f1(events, weights) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: bm.jitconn.mv_prob_homo( + event, data, + conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose + )[0] ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( - event, data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - - self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - x64=x64, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] + r1 = f1(events, weights) + r2 = f1(events, weights) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: bm.jitconn.mv_prob_homo( + event, data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum(), + argnums=0 ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): - print(f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64 = {x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_uniform(events, + r1 = f1(events, 1.) + r2 = f1(events, 2.) + + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_low=[-0.1], + w_high=[1.0], + ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): + print(f'test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64 = {x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -222,7 +170,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_uniform(events, + r2 = bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -230,45 +178,35 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e, + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform(e, w_low=0., w_high=1., conn_prob=prob, @@ -277,107 +215,81 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, outdim_parallel=outdim_parallel, transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - f1 = jax.grad( - lambda e, w_low, w_high: taichi_mv_prob_uniform( - e, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() - ) - - r1 = f1(events, 0., 1.) - r2 = f1(events, 0., 2.) - - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=(f'test_normal, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma},' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - seed=1234 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] + r1 = f1(events) + r2 = f1(events) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + f1 = jax.grad( + lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform( + e, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): - print(f'_test_normal: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_normal(events, + + r1 = f1(events, 0., 1.) + r2 = f1(events, 0., 2.) + + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_mu=[0.], + w_sigma=[0.2] + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): + print(f'_test_normal: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_mu = {w_mu}, ' + f'w_sigma = {w_sigma}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -386,7 +298,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_normal(events, + r2 = bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -394,46 +306,36 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_normal_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal(e, w_mu=0., w_sigma=1., conn_prob=prob, @@ -441,66 +343,54 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - print(r1 - r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64, - testcase_name=f'test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + r1 = f1(events) + r2 = f1(events) + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + print(r1 - r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_sigma: bm.jitconn.mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + r1 = f1(events, 1.) + r2 = f1(events, 2.) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index d5353324..14256cbc 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,8 +1,8 @@ # from ._coo_mv import * # from ._bsr_mv import * -from ._csr_mv import * -from ._utils import * -from ._bsr_mm import * -from ._jax_prim import * +from .csr_mv import * +from .utils import * +from .bsr_mm import * +from .jax_prim import * diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/bsr_mm.py similarity index 100% rename from brainpy/_src/math/sparse/_bsr_mm.py rename to brainpy/_src/math/sparse/bsr_mm.py diff --git a/brainpy/_src/math/sparse/_bsr_mv.py b/brainpy/_src/math/sparse/bsr_mv.py similarity index 99% rename from brainpy/_src/math/sparse/_bsr_mv.py rename to brainpy/_src/math/sparse/bsr_mv.py index a35895bc..7dc0b683 100644 --- a/brainpy/_src/math/sparse/_bsr_mv.py +++ b/brainpy/_src/math/sparse/bsr_mv.py @@ -11,7 +11,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy._src.dependency_check import import_brainpylib_gpu_ops from brainpy.errors import GPUOperatorNotFound diff --git a/brainpy/_src/math/sparse/_coo_mv.py b/brainpy/_src/math/sparse/coo_mv.py similarity index 100% rename from brainpy/_src/math/sparse/_coo_mv.py rename to brainpy/_src/math/sparse/coo_mv.py diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/csr_mv.py similarity index 99% rename from brainpy/_src/math/sparse/_csr_mv.py rename to brainpy/_src/math/sparse/csr_mv.py index 42969f43..31eec80d 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/csr_mv.py @@ -13,7 +13,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (register_general_batching, XLACustomOp) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy.errors import PackageMissingError ti = import_taichi(error_if_not_found=False) diff --git a/brainpy/_src/math/sparse/_jax_prim.py b/brainpy/_src/math/sparse/jax_prim.py similarity index 100% rename from brainpy/_src/math/sparse/_jax_prim.py rename to brainpy/_src/math/sparse/jax_prim.py diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index ec448e65..12dc6a59 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -3,12 +3,11 @@ from functools import partial import jax +import pytest from absl.testing import parameterized -import pytest import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: @@ -25,7 +24,6 @@ def func(*args, **kwargs): return func - def compare_with_nan_tolerance(a, b, tol=1e-8): """ Compare two arrays with tolerance for NaN values. @@ -58,6 +56,7 @@ def compare_with_nan_tolerance(a, b, tol=1e-8): taichi_csr_matvec = bm.sparse.csrmv + class Test_csrmv_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): super(Test_csrmv_taichi, self).__init__(*args, **kwargs) @@ -67,8 +66,8 @@ def __init__(self, *args, platform='cpu', **kwargs): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] + shape=[(200, 200), (10, 1000)], + homo_data=[1.] ) def test_homo(self, transpose, shape, homo_data): print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') @@ -94,8 +93,8 @@ def test_homo(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (100, 1000), (2, 2000)], - v=[-1., 0., 1.] + shape=[(200, 200), (100, 1000)], + v=[1.] ) def test_homo_vmap(self, transpose, shape, v): print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}') @@ -123,8 +122,8 @@ def test_homo_vmap(self, transpose, shape, v): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] + shape=[(200, 200), (10, 1000)], + homo_data=[1.] ) def test_homo_grad(self, transpose, shape, homo_data): print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') @@ -177,7 +176,7 @@ def test_homo_grad(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (2, 2000)], + shape=[(200, 200), (2, 2000)], ) def test_heter(self, transpose, shape): print(f'test_homo: transpose = {transpose} shape = {shape}') @@ -204,7 +203,7 @@ def test_heter(self, transpose, shape): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + shape=[(200, 200), (2, 2000)] ) def test_heter_vmap(self, transpose, shape): rng = bm.random.RandomState(seed=seed) @@ -230,7 +229,7 @@ def test_heter_vmap(self, transpose, shape): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + shape=[(200, 200), (2, 2000)] ) def test_heter_grad(self, transpose, shape): rng = bm.random.RandomState(seed=seed) @@ -249,8 +248,8 @@ def test_heter_grad(self, transpose, shape): dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), argnums=0) csr_f1 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, - shape=shape, - transpose=transpose).sum(), + shape=shape, + transpose=transpose).sum(), argnums=0) r1 = csr_f1(heter_data) r2 = dense_f1(dense_data) @@ -263,9 +262,9 @@ def test_heter_grad(self, transpose, shape): dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), argnums=0) csr_f2 = jax.grad(lambda v: taichi_csr_matvec(heter_data, indices, indptr, v, - shape=shape, - transpose=transpose).sum(), - argnums=0) + shape=shape, + transpose=transpose).sum(), + argnums=0) r3 = dense_f2(vector) r4 = csr_f2(vector) self.assertTrue(bm.allclose(r3, r4)) diff --git a/brainpy/_src/math/sparse/_utils.py b/brainpy/_src/math/sparse/utils.py similarity index 100% rename from brainpy/_src/math/sparse/_utils.py rename to brainpy/_src/math/sparse/utils.py From 5cbc5243c3688828eaaded84555c5dc85e4262e0 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 1 Mar 2024 15:36:51 +0800 Subject: [PATCH 03/11] recovery --- brainpy/_src/math/op_register/tests/test_taichi_based.py | 2 +- .../_src/math/op_register/tests/test_taichi_clean_cache.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 0fbcca3b..2d3af382 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import pytest -pytest.skip('Old implementation.', allow_module_level=True) +# pytest.skip('Old implementation.', allow_module_level=True) import brainpy.math as bm from brainpy._src.dependency_check import import_taichi diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py index 51c964b2..fc9beccb 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py @@ -9,8 +9,8 @@ if ti is None: pytest.skip('no taichi', allow_module_level=True) -if not platform.platform().startswith('Windows'): - pytest.skip(allow_module_level=True) +# if not platform.platform().startswith('Windows'): +# pytest.skip(allow_module_level=True) @ti.func def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: From e6d5a3f78a6eaaaa7e02c73fa6537651b8419265 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 1 Mar 2024 16:09:14 +0800 Subject: [PATCH 04/11] upgrade tests --- brainpy/_src/dnn/tests/test_linear.py | 34 +++++++++---------- brainpy/_src/dnn/tests/test_mode.py | 2 -- .../_src/math/op_register/taichi_aot_based.py | 7 ++++ .../op_register/tests/test_taichi_based.py | 3 +- brainpy/_src/math/sparse/tests/test_csrmv.py | 27 ++++++--------- requirements-dev-raw.txt | 13 ------- 6 files changed, 36 insertions(+), 50 deletions(-) delete mode 100644 requirements-dev-raw.txt diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 422f161f..6cc44538 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -20,7 +20,7 @@ def __init__(self, *args, **kwargs): size=[(10,), (20, 10), (5, 8, 10)], - num_out=[20, 10, 5] + num_out=[20,] ) def test_Dense1(self, size, num_out): bm.random.seed() @@ -131,8 +131,8 @@ def test_EventCSRLinear(self, conn): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], + prob=[0.1], + weight=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_JitFPHomoLinear(self, prob, weight, shape): @@ -144,9 +144,9 @@ def test_JitFPHomoLinear(self, prob, weight, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], + prob=[0.1], + w_low=[-0.01, ], + w_high=[0.01, ], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): @@ -158,9 +158,9 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], + prob=[0.1], + w_mu=[-0.01], + w_sigma=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): @@ -172,8 +172,8 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], + prob=[0.1], + weight=[0.01,], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_EventJitFPHomoLinear(self, prob, weight, shape): @@ -187,9 +187,9 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], + prob=[0.1], + w_low=[-0.01], + w_high=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): @@ -203,9 +203,9 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], + prob=[0.1], + w_mu=[-0.01], + w_sigma=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index f0c67da1..10e9eeda 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -4,7 +4,6 @@ import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: @@ -63,7 +62,6 @@ def test_Conv2_NonBatching(self): mode=bm.NonBatchingMode()) output = layer(input) bm.clear_buffer_memory() - bm.clear_buffer_memory() @parameterized.product( mode=[bm.TrainingMode(), diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 7fac4452..f9328906 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -16,6 +16,7 @@ from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call +from brainpy.errors import PackageMissingError from brainpy._src.dependency_check import (import_taichi, import_brainpylib_cpu_ops, import_brainpylib_gpu_ops) @@ -485,10 +486,16 @@ def _taichi_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): def register_taichi_aot_mlir_cpu_translation_rule(primitive, cpu_kernel): + if import_taichi(error_if_not_found=False) is None: + raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule') + rule = partial(_taichi_mlir_cpu_translation_rule, cpu_kernel) mlir.register_lowering(primitive, rule, platform='cpu') def register_taichi_aot_mlir_gpu_translation_rule(primitive, gpu_kernel): + if import_taichi(error_if_not_found=False) is None: + raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule') + rule = partial(_taichi_mlir_gpu_translation_rule, gpu_kernel) mlir.register_lowering(primitive, rule, platform='gpu') diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 2d3af382..199dce98 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -1,7 +1,6 @@ import jax import jax.numpy as jnp import pytest -# pytest.skip('Old implementation.', allow_module_level=True) import brainpy.math as bm from brainpy._src.dependency_check import import_taichi @@ -55,7 +54,7 @@ def event_ell_gpu(indices: ti.types.ndarray(ndim=2), def test_taichi_op_register(): s = 1000 - indices = bm.random.randint(0, s, (s, 1000)) + indices = bm.random.randint(0, s, (s, 100)) vector = bm.random.rand(s) < 0.1 weight = bm.array([1.0]) diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 12dc6a59..074ac6a9 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -54,9 +54,6 @@ def compare_with_nan_tolerance(a, b, tol=1e-8): return bm.allclose(a_non_nan, b_non_nan, atol=tol) -taichi_csr_matvec = bm.sparse.csrmv - - class Test_csrmv_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): super(Test_csrmv_taichi, self).__init__(*args, **kwargs) @@ -86,7 +83,7 @@ def test_homo(self, transpose, shape, homo_data): dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (vector @ dense) if transpose else (dense @ vector) - r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = bm.sparse.csrmv(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() @@ -112,7 +109,7 @@ def test_homo_vmap(self, transpose, shape, v): dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) - f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, + f2 = partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) r1 = jax.vmap(f1)(dense_data) r2 = jax.vmap(f2)(homo_data) @@ -147,7 +144,7 @@ def test_homo_grad(self, transpose, shape, homo_data): ((dense * a) @ vector).sum()), argnums=0) r1 = dense_f1(homo_data) - r2 = jax.grad(sum_op(taichi_csr_matvec))( + r2 = jax.grad(sum_op(bm.sparse.csrmv))( homo_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -157,7 +154,7 @@ def test_homo_grad(self, transpose, shape, homo_data): dense_data = dense * homo_data dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) r3 = dense_f2(vector) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( + r4 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) @@ -167,7 +164,7 @@ def test_homo_grad(self, transpose, shape, homo_data): ((dense * a) @ v).sum()), argnums=(0, 1)) r5 = dense_f3(homo_data, vector) - r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( + r6 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) @@ -195,7 +192,7 @@ def test_heter(self, transpose, shape): dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (vector @ dense) if transpose else (dense @ vector) - r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = bm.sparse.csrmv(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(compare_with_nan_tolerance(r1, r2)) @@ -221,7 +218,7 @@ def test_heter_vmap(self, transpose, shape): shape=shape))(heter_data) f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) - f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, + f2 = partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) r1 = jax.vmap(f1)(dense_data) r2 = jax.vmap(f2)(heter_data) @@ -245,9 +242,8 @@ def test_heter_grad(self, transpose, shape): vector = bm.as_jax(vector) # grad 'data' - dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), - argnums=0) - csr_f1 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, + dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), argnums=0) + csr_f1 = jax.grad(lambda a: bm.sparse.csrmv(a, indices, indptr, vector, shape=shape, transpose=transpose).sum(), argnums=0) @@ -259,9 +255,8 @@ def test_heter_grad(self, transpose, shape): self.assertTrue(bm.allclose(r1, r2)) # grad 'vector' - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), - argnums=0) - csr_f2 = jax.grad(lambda v: taichi_csr_matvec(heter_data, indices, indptr, v, + dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), argnums=0) + csr_f2 = jax.grad(lambda v: bm.sparse.csrmv(heter_data, indices, indptr, v, shape=shape, transpose=transpose).sum(), argnums=0) diff --git a/requirements-dev-raw.txt b/requirements-dev-raw.txt deleted file mode 100644 index 1234b836..00000000 --- a/requirements-dev-raw.txt +++ /dev/null @@ -1,13 +0,0 @@ -numpy -brainpylib -jax -jaxlib -matplotlib -msgpack -tqdm -pathos - - -# test requirements -pytest -absl-py From 343476ff9186814c757a45a3fd5784066dd0cdec Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Fri, 1 Mar 2024 18:14:36 +0800 Subject: [PATCH 05/11] Update CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b6902bde..e9f3cc01 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ on: jobs: test_linux: - runs-on: ubuntu-latest + runs-on: self-hosted strategy: fail-fast: false matrix: From a0bda45df9ce5a98f0429bf94e133946e988025b Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 2 Mar 2024 16:38:07 +0800 Subject: [PATCH 06/11] Revert "Update CI.yml" This reverts commit 343476ff9186814c757a45a3fd5784066dd0cdec. --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index e9f3cc01..b6902bde 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ on: jobs: test_linux: - runs-on: self-hosted + runs-on: ubuntu-latest strategy: fail-fast: false matrix: From 31f9643645831203bc40729a4fa594b83ef8ca28 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 2 Mar 2024 16:39:37 +0800 Subject: [PATCH 07/11] update tests --- .../math/op_register/tests/test_numba_based.py | 8 ++++---- .../op_register/tests/test_taichi_clean_cache.py | 16 +++++++++------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py index dc093f62..28b80d0f 100644 --- a/brainpy/_src/math/op_register/tests/test_numba_based.py +++ b/brainpy/_src/math/op_register/tests/test_numba_based.py @@ -1,14 +1,16 @@ -import pytest import jax.core -import brainpy.math as bm +import pytest +import brainpy.math as bm from brainpy._src.dependency_check import import_numba + numba = import_numba(error_if_not_found=False) if numba is None: pytest.skip('no numba', allow_module_level=True) bm.set_platform('cpu') + @numba.njit(fastmath=True) def numba_event_csrmv(weight, indices, vector, outs): outs.fill(0) @@ -33,5 +35,3 @@ def test_event_ELL(): call(1000) call(100) bm.clear_buffer_memory() - - diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py index fc9beccb..5b27b2fd 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py @@ -1,16 +1,15 @@ -import brainpy.math as bm import jax import jax.numpy as jnp -import platform -import pytest + +import brainpy.math as bm +import taichi as ti from brainpy._src.dependency_check import import_taichi ti = import_taichi(error_if_not_found=False) if ti is None: + import pytest pytest.skip('no taichi', allow_module_level=True) -# if not platform.platform().startswith('Windows'): -# pytest.skip(allow_module_level=True) @ti.func def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: @@ -21,6 +20,7 @@ def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32): out[index] += weight_val + @ti.kernel def event_ell_cpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), @@ -34,11 +34,13 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) + prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) + def test_taichi_clean_cache(): s = 1000 - indices = bm.random.randint(0, s, (s, 1000)) + indices = bm.random.randint(0, s, (s, 100)) vector = bm.random.rand(s) < 0.1 weight = bm.array([1.0]) @@ -55,4 +57,4 @@ def test_taichi_clean_cache(): print('kernels: ', bm.check_kernels_count()) -# test_taichi_clean_cache() \ No newline at end of file +# test_taichi_clean_cache() From b351a482a21c154cf0208807b28d2309f47f06ba Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 2 Mar 2024 16:42:10 +0800 Subject: [PATCH 08/11] update tests --- brainpy/_src/math/sparse/tests/test_csrmv.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 074ac6a9..40bcbb70 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -83,7 +83,7 @@ def test_homo(self, transpose, shape, homo_data): dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (vector @ dense) if transpose else (dense @ vector) - r2 = bm.sparse.csrmv(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = bm.sparse.csrmv(bm.asarray([homo_data]), indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() @@ -144,8 +144,7 @@ def test_homo_grad(self, transpose, shape, homo_data): ((dense * a) @ vector).sum()), argnums=0) r1 = dense_f1(homo_data) - r2 = jax.grad(sum_op(bm.sparse.csrmv))( - homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(bm.sparse.csrmv))(bm.asarray([homo_data]), indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -155,7 +154,7 @@ def test_homo_grad(self, transpose, shape, homo_data): dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) r3 = dense_f2(vector) r4 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + bm.asarray([homo_data]), indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) @@ -165,7 +164,7 @@ def test_homo_grad(self, transpose, shape, homo_data): argnums=(0, 1)) r5 = dense_f3(homo_data, vector) r6 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + bm.asarray([homo_data]), indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) From 1774fdae9121e2465dfdfbdd039c2cc298f196ef Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 2 Mar 2024 16:45:02 +0800 Subject: [PATCH 09/11] update CI --- .github/workflows/CI.yml | 88 +++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 38 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b6902bde..c8b057ce 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -11,11 +11,11 @@ on: branches: - '**' # matches every branch -#on: -# push: -# branches: [ master ] -# pull_request: -# branches: [ master ] + +permissions: + contents: read # to fetch code + actions: write # to cancel previous workflows + jobs: @@ -24,9 +24,13 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ "3.9", "3.10", "3.11"] + python-version: [ "3.9", "3.10", "3.11" ] steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 @@ -49,24 +53,28 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: [ "3.9", "3.10", "3.11" ] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi - pip uninstall brainpy -y - python setup.py install - - name: Test with pytest - run: | - cd brainpy - pytest _src/ + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi + pip uninstall brainpy -y + python setup.py install + - name: Test with pytest + run: | + cd brainpy + pytest _src/ test_windows: @@ -74,21 +82,25 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: [ "3.9", "3.10", "3.11" ] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install -r requirements-dev.txt - pip uninstall brainpy -y - python setup.py install - - name: Test with pytest - run: | - cd brainpy - pytest _src/ + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements-dev.txt + pip uninstall brainpy -y + python setup.py install + - name: Test with pytest + run: | + cd brainpy + pytest _src/ From b4f0b7da6879963f69844469cce2a4aa8180ac15 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 2 Mar 2024 16:47:35 +0800 Subject: [PATCH 10/11] update CI --- .github/workflows/CI.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c8b057ce..5b6b7357 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,7 +16,9 @@ permissions: contents: read # to fetch code actions: write # to cancel previous workflows - +concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: true jobs: test_linux: From 6b68831096ec76b5e195375435862d48d1503d67 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 2 Mar 2024 16:50:47 +0800 Subject: [PATCH 11/11] update CI --- .github/workflows/CI.yml | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5b6b7357..7f46c959 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,10 +16,12 @@ permissions: contents: read # to fetch code actions: write # to cancel previous workflows +# This is what will cancel the workflow concurrency: - group: ${{ github.head_ref }} + group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true + jobs: test_linux: runs-on: ubuntu-latest @@ -30,10 +32,12 @@ jobs: steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + uses: styfle/cancel-workflow-action@0.12.1 with: access_token: ${{ github.token }} - uses: actions/checkout@v4 + - name: Print concurrency group + run: echo '${{ github.workflow }}-${{ github.ref }}' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -59,10 +63,12 @@ jobs: steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + uses: styfle/cancel-workflow-action@0.12.1 with: access_token: ${{ github.token }} - uses: actions/checkout@v4 + - name: Print concurrency group + run: echo '${{ github.workflow }}-${{ github.ref }}' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -88,10 +94,12 @@ jobs: steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + uses: styfle/cancel-workflow-action@0.12.1 with: access_token: ${{ github.token }} - uses: actions/checkout@v4 + - name: Print concurrency group + run: echo '${{ github.workflow }}-${{ github.ref }}' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: