diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d29b07ebc..7f46c9593 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,12 +16,10 @@ permissions: contents: read # to fetch code actions: write # to cancel previous workflows - -#on: -# push: -# branches: [ master ] -# pull_request: -# branches: [ master ] +# This is what will cancel the workflow +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: @@ -30,14 +28,16 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ "3.9", "3.10", "3.11"] + python-version: [ "3.9", "3.10", "3.11" ] steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + uses: styfle/cancel-workflow-action@0.12.1 with: - access_token: ${{ github.token }} + access_token: ${{ github.token }} - uses: actions/checkout@v4 + - name: Print concurrency group + run: echo '${{ github.workflow }}-${{ github.ref }}' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -53,15 +53,22 @@ jobs: cd brainpy pytest _src/ - test_linux_with_taichi_numba: - runs-on: ubuntu-latest + + test_macos: + runs-on: macos-latest strategy: fail-fast: false matrix: - python-version: [ "3.9", "3.10", "3.11"] + python-version: [ "3.9", "3.10", "3.11" ] steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} - uses: actions/checkout@v4 + - name: Print concurrency group + run: echo '${{ github.workflow }}-${{ github.ref }}' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -69,109 +76,41 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest taichi numba - if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi + if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | cd brainpy pytest _src/ - test_macos: - runs-on: macos-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.9", "3.10", "3.11"] - - steps: - - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi - pip uninstall brainpy -y - python setup.py install - - name: Test with pytest - run: | - cd brainpy - pytest -n auto --tb=short _src/ - - test_windows: + runs-on: windows-latest strategy: fail-fast: false matrix: - os: [ win-2019-16core ] - arch: [ AMD64 ] - python-version: ["3.9", "3.10", "3.11"] - runs-on: ${{ matrix.os }} - - steps: - - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install -r requirements-dev.txt - pip uninstall brainpy -y - python setup.py install - - name: Test with pytest - run: | - cd brainpy - pytest _src/ - - test_macos_with_taichi_numba: - runs-on: macos-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: [ "3.9", "3.10", "3.11" ] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest taichi numba - if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi - pip uninstall brainpy -y - python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - cd brainpy - pytest _src/ - + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v4 + - name: Print concurrency group + run: echo '${{ github.workflow }}-${{ github.ref }}' + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements-dev.txt + pip uninstall brainpy -y + python setup.py install + - name: Test with pytest + run: | + cd brainpy + pytest _src/ diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py index b634d6dbe..07422f65e 100644 --- a/brainpy/_src/connect/tests/test_all_time.py +++ b/brainpy/_src/connect/tests/test_all_time.py @@ -1,18 +1,19 @@ import time from datetime import datetime -import brainpy as bp -import unittest import pytest +import brainpy as bp + +pytest.skip('skip.', allow_module_level=True) + try: - import pandas as pd + import pandas as pd - df = pd.DataFrame( - columns=['connector name', 'superclass', 'connect matrix size', 'build function', 'other parameter', - 'time(ms)']) + df = pd.DataFrame(columns=['connector name', 'superclass', 'connect matrix size', + 'build function', 'other parameter', 'time(ms)']) except (ImportError, ModuleNotFoundError): - print('No pandas installed, skip test.') + print('No pandas installed, skip test.') # size_same = [100, 500, 2500, 12500, 25000, 37500, 50000] # size_same = [100, 500, 2500, 12500] @@ -21,644 +22,645 @@ size_same = [100, 500, 2500] size_diff = [(10, 100), (100, 1000)] + def get_ms(value): - return round(value * 1000, 4) + return round(value * 1000, 4) def insert_row(connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used): - try: - df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used] - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + try: + df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used] + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') class OneEndConnector(unittest.TestCase): - def test_gaussian_prob(self): - print() - for size in size_same: - print('GaussianProb:', size) - conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GaussianProb', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'sigma=1/include_self=False', - time_used) - - # start = time.time() - # conn.require(bp.connect.COO) - # time_used = get_ms(time.time() - start) - # df.loc[len(df)] = ['GaussianProb', - # 'OneEndConnector', - # f'{size}x{size}', - # 'build_coo', - # 'sigma=1/include_self=False', - # time_used] - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GaussianProb', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'sigma=1/include_self=False', - time_used) - - def test_grid_four(self): - print() - for size in size_same: - print('GridFour:', size) - conn = 
bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False', - time_used) - - def test_grid_eight(self): - print() - for size in size_same: - print('GridEight:', size) - conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False', - time_used) - - def test_grid_n(self): - print() - for size in size_same: - print('GridN:', size) - conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False/N=2', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False/N=2', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False/N=2', - time_used) + def test_gaussian_prob(self): + print() + for size in size_same: + print('GaussianProb:', size) + conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GaussianProb', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'sigma=1/include_self=False', + time_used) + + # start = time.time() + # conn.require(bp.connect.COO) + # time_used = get_ms(time.time() - start) + # df.loc[len(df)] = ['GaussianProb', + # 'OneEndConnector', + # f'{size}x{size}', + # 'build_coo', + # 'sigma=1/include_self=False', + # time_used] + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GaussianProb', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'sigma=1/include_self=False', + time_used) + + def test_grid_four(self): + print() + for size 
in size_same: + print('GridFour:', size) + conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False', + time_used) + + def test_grid_eight(self): + print() + for size in size_same: + print('GridEight:', size) + conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False', + time_used) + + def test_grid_n(self): + print() + for size in size_same: + print('GridN:', size) + conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False/N=2', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False/N=2', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False/N=2', + time_used) class TwoEndConnector(unittest.TestCase): - def test_fixed_prob(self): - print() - for size in size_same: - print('FixedProb:', size) - conn = bp.connect.FixedProb(prob=0.1, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'prob=0.1', - time_used) - - for size in size_diff: - print('FixedProb:', size) 
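The module-level pytest.skip added at the top of this file short-circuits the whole benchmark module. For reference, a minimal sketch of a related optional-dependency guard using pytest.importorskip, which skips the module when pandas is missing instead of falling back to print statements; this is illustrative only and not part of the patch:

  import pytest

  # Skip this module entirely if pandas is not installed; otherwise return the module.
  pd = pytest.importorskip('pandas')

  df = pd.DataFrame(columns=['connector name', 'superclass', 'connect matrix size',
                             'build function', 'other parameter', 'time(ms)'])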
- conn = bp.connect.FixedProb(prob=0.1, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'prob=0.1', - time_used) - - def test_fixed_pre_num(self): - print() - for size in size_same: - print('FixedPreNum:', size) - conn = bp.connect.FixedPreNum(num=0.4, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'pre_num=10', - time_used) - - for size in size_diff: - print('FixedPreNum:', size) - conn = bp.connect.FixedPreNum(num=0.4, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'pre_num=10', - time_used) - - def test_fixed_post_num(self): - print() - for size in size_same: - print('FixedPostNum:', size) - conn = bp.connect.FixedPostNum(num=10, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - mat = conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'num=10', - time_used) - - for size in size_diff: - print('FixedPostNum:', size) - conn = bp.connect.FixedPreNum(num=10, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'pre_num=10', - time_used) 
- - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'pre_num=10', - time_used) - - def test_prob_dist(self): - print() - for size in size_same: - print('ProbDist:', size) - conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'prob=0.5', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', - time_used) - - def test_small_world(self): - print() - for size in size_same: - print('SmallWorld:', size) - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - def test_scale_free_ba(self): - print() - for size in size_same: - print('ScaleFreeBA:', size) - conn = bp.connect.ScaleFreeBA(m=2) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm=2', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm=2', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm=2', - time_used) - - def test_scale_free_ba_dual(self): - print() - for size in size_same: - print('ScaleFreeBADual:', size) - conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm1=2/m2=3/p=0.4', - time_used) - - start = time.time() - 
conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm1=2/m2=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm1=2/m2=3/p=0.4', - time_used) - - def test_power_law(self): - print() - for size in size_same: - print('PowerLaw:', size) - conn = bp.connect.PowerLaw(m=3, p=0.4) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm=3/p=0.4', - time_used) - - def test_one2one(self): - print() - for size in size_same: - print('One2One:', size) - conn = bp.connect.One2One() - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - '', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - '', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - '', - time_used) - - def test_all2all(self): - print() - for size in size_same: - print('All2All:', size) - conn = bp.connect.All2All() - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('All2All', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - '', - time_used) - - # start = time.time() - # conn.require(bp.connect.COO) - # time_used = get_ms(time.time() - start) - # df.loc[len(df)] = ['All2All', - # 'TwoEndConnector', - # f'{size}x{size}', - # 'build_coo', - # '', - # time_used] - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('All2All', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - '', - time_used) + def test_fixed_prob(self): + print() + for size in size_same: + print('FixedProb:', size) + conn = bp.connect.FixedProb(prob=0.1, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'prob=0.1', + time_used) + + for size in 
size_diff: + print('FixedProb:', size) + conn = bp.connect.FixedProb(prob=0.1, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'prob=0.1', + time_used) + + def test_fixed_pre_num(self): + print() + for size in size_same: + print('FixedPreNum:', size) + conn = bp.connect.FixedPreNum(num=0.4, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'pre_num=10', + time_used) + + for size in size_diff: + print('FixedPreNum:', size) + conn = bp.connect.FixedPreNum(num=0.4, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'pre_num=10', + time_used) + + def test_fixed_post_num(self): + print() + for size in size_same: + print('FixedPostNum:', size) + conn = bp.connect.FixedPostNum(num=10, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + mat = conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'num=10', + time_used) + + for size in size_diff: + print('FixedPostNum:', size) + conn = bp.connect.FixedPreNum(num=10, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 
'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'pre_num=10', + time_used) + + def test_prob_dist(self): + print() + for size in size_same: + print('ProbDist:', size) + conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'prob=0.5', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', + time_used) + + def test_small_world(self): + print() + for size in size_same: + print('SmallWorld:', size) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + def test_scale_free_ba(self): + print() + for size in size_same: + print('ScaleFreeBA:', size) + conn = bp.connect.ScaleFreeBA(m=2) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm=2', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm=2', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm=2', + time_used) + + def test_scale_free_ba_dual(self): + print() + for size in size_same: + print('ScaleFreeBADual:', size) + conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm1=2/m2=3/p=0.4', + time_used) + + 
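The timing blocks in these tests all follow one pattern: time conn.require(...) for CONN_MAT, COO and CSR, convert to milliseconds, and record one row. A condensed sketch of that pattern, using a hypothetical helper name together with the insert_row and get_ms helpers defined earlier in this file (illustrative only, not part of the patch):

  import time
  import brainpy as bp

  def time_structures(conn, name, superclass, size_str, params):
    # Measure each connection structure the same way the tests above do.
    for label, structure in (('build_mat', bp.connect.CONN_MAT),
                             ('build_coo', bp.connect.COO),
                             ('build_csr', bp.connect.CSR)):
      start = time.time()
      conn.require(structure)
      insert_row(name, superclass, size_str, label, params, get_ms(time.time() - start))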
start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm1=2/m2=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm1=2/m2=3/p=0.4', + time_used) + + def test_power_law(self): + print() + for size in size_same: + print('PowerLaw:', size) + conn = bp.connect.PowerLaw(m=3, p=0.4) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm=3/p=0.4', + time_used) + + def test_one2one(self): + print() + for size in size_same: + print('One2One:', size) + conn = bp.connect.One2One() + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + '', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + '', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + '', + time_used) + + def test_all2all(self): + print() + for size in size_same: + print('All2All:', size) + conn = bp.connect.All2All() + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('All2All', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + '', + time_used) + + # start = time.time() + # conn.require(bp.connect.COO) + # time_used = get_ms(time.time() - start) + # df.loc[len(df)] = ['All2All', + # 'TwoEndConnector', + # f'{size}x{size}', + # 'build_coo', + # '', + # time_used] + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('All2All', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + '', + time_used) class TestSave(unittest.TestCase): - def test_save(self): - try: - df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', - index=False) - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + def test_save(self): + try: + df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', + index=False) + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 422f161f1..6cc445383 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -20,7 +20,7 @@ def __init__(self, *args, **kwargs): size=[(10,), (20, 10), (5, 8, 10)], - num_out=[20, 10, 5] + 
num_out=[20,] ) def test_Dense1(self, size, num_out): bm.random.seed() @@ -131,8 +131,8 @@ def test_EventCSRLinear(self, conn): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], + prob=[0.1], + weight=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_JitFPHomoLinear(self, prob, weight, shape): @@ -144,9 +144,9 @@ def test_JitFPHomoLinear(self, prob, weight, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], + prob=[0.1], + w_low=[-0.01, ], + w_high=[0.01, ], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): @@ -158,9 +158,9 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], + prob=[0.1], + w_mu=[-0.01], + w_sigma=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): @@ -172,8 +172,8 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], + prob=[0.1], + weight=[0.01,], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_EventJitFPHomoLinear(self, prob, weight, shape): @@ -187,9 +187,9 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], + prob=[0.1], + w_low=[-0.01], + w_high=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): @@ -203,9 +203,9 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], + prob=[0.1], + w_mu=[-0.01], + w_sigma=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index f0c67da12..10e9eeda2 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -4,7 +4,6 @@ import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: @@ -63,7 +62,6 @@ def test_Conv2_NonBatching(self): mode=bm.NonBatchingMode()) output = layer(input) bm.clear_buffer_memory() - bm.clear_buffer_memory() @parameterized.product( mode=[bm.TrainingMode(), diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py index 6db945ff2..d068f2079 100644 --- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py @@ -3,9 +3,14 @@ from absl.testing import parameterized +import pytest import brainpy as bp import brainpy.math as bm from brainpy._src.dynold.synapses import abstract_models +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) class Test_Abstract_Synapse(parameterized.TestCase): diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index bdd3102a3..9ebad3e94 100644 --- a/brainpy/_src/math/event/__init__.py +++ 
b/brainpy/_src/math/event/__init__.py @@ -1,2 +1,2 @@ -from ._csr_matvec import * +from .csr_matvec import * diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py similarity index 99% rename from brainpy/_src/math/event/_csr_matvec.py rename to brainpy/_src/math/event/csr_matvec.py index 6b7f7da02..9890838e7 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -20,8 +20,8 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.csr_mv import raw_csrmv_taichi as normal_csrmv_taichi +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy.errors import PackageMissingError __all__ = [ @@ -30,6 +30,7 @@ ti = import_taichi(error_if_not_found=False) + def csrmv( data: Union[float, jax.Array], indices: jax.Array, diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 67e09d0a4..6c0a2ed47 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -9,13 +9,11 @@ import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) - seed = 1234 @@ -26,7 +24,6 @@ def func(*args, **kwargs): return func -taichi_csr_matvec = bm.event.csrmv class Test_event_csr_matvec_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -37,22 +34,22 @@ def __init__(self, *args, platform='cpu', **kwargs): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo(self, transpose, shape, homo_data): print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') events = rng.random(shape[0] if transpose else shape[1]) < 0.1 heter_data = bm.ones(indices.shape) * homo_data dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (events @ dense) if transpose else (dense @ events) - r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -60,23 +57,22 @@ def test_homo(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo_vmap(self, shape, transpose, homo_data): print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 f1 = 
jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, + f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax([homo_data] * 10) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) @@ -84,14 +80,14 @@ def test_homo_vmap(self, shape, transpose, homo_data): # vmap 'events' f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr, + f4 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax([homo_data] * 10) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 @@ -102,16 +98,15 @@ def test_homo_vmap(self, shape, transpose, homo_data): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo_grad(self, shape, transpose, homo_data): print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -119,31 +114,26 @@ def test_homo_grad(self, shape, transpose, homo_data): dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) # grad 'data' - r1 = jax.grad(sum_op(bm.sparse.csrmv))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r1 = jax.grad(sum_op(bm.sparse.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(bm.event.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' - r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape, + transpose=transpose) + r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape, + transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) bm.clear_buffer_memory() @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000), ] + shape=[(100, 200), (10, 1000), ] ) def test_heter(self, shape, 
transpose): print(f'test_heter: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -151,9 +141,9 @@ def test_heter(self, shape, transpose): heter_data = bm.as_jax(rng.random(indices.shape)) r1 = bm.sparse.csrmv(heter_data, indices, indptr, events, + shape=shape, transpose=transpose) + r2 = bm.event.csrmv(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = taichi_csr_matvec(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -161,24 +151,21 @@ def test_heter(self, shape, transpose): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] + shape=[(100, 200), (10, 1000)] ) def test_heter_vmap(self, shape, transpose): print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, + f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, + f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) @@ -187,16 +174,16 @@ def test_heter_vmap(self, shape, transpose): data = bm.as_jax(rng.random(indices.shape)) f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr, + f4 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, + shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), @@ -206,15 +193,12 @@ def test_heter_vmap(self, shape, transpose): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] + shape=[(100, 200), (10, 1000)] ) def test_heter_grad(self, shape, transpose): print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -226,20 +210,20 @@ def test_heter_grad(self, shape, transpose): data = 
bm.as_jax(rng.random(indices.shape)) r1 = jax.grad(sum_op(bm.sparse.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( + r2 = jax.grad(sum_op(bm.event.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( + r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( + r6 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index 6f7cddf6a..bb6a3c1f4 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,2 +1,2 @@ -from ._matvec import * -from ._event_matvec import * +from .matvec import * +from .event_matvec import * diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py similarity index 56% rename from brainpy/_src/math/jitconn/_event_matvec.py rename to brainpy/_src/math/jitconn/event_matvec.py index ac62bbfaf..279980380 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -8,17 +8,17 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.jitconn._matvec import (mv_prob_homo, - mv_prob_uniform, - mv_prob_normal, - _general_checking, - raw_mv_prob_homo, - raw_mv_prob_uniform, - raw_mv_prob_normal, - _mv_prob_homo_transpose, - _mv_prob_uniform_transpose, - _mv_prob_normal_transpose, - _reverse) +from brainpy._src.math.jitconn.matvec import (mv_prob_homo, + mv_prob_uniform, + mv_prob_normal, + _general_checking, + raw_mv_prob_homo, + raw_mv_prob_uniform, + raw_mv_prob_normal, + _mv_prob_homo_transpose, + _mv_prob_uniform_transpose, + _mv_prob_normal_transpose, + _reverse) from brainpy._src.math.ndarray import _get_dtype from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError @@ -45,743 +45,6 @@ def event_mv_prob_homo( if ti is None: raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - -event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ - - -def event_mv_prob_uniform( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ - - -def event_mv_prob_normal( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return 
event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -### BRAINPYLIB ### - -def event_mv_prob_homo_brainpylib( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - weight = jnp.atleast_1d(jnp.asarray(weight)) - conn_prob = jnp.atleast_1d(jnp.asarray(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - return r - - -event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ - - -def event_mv_prob_uniform_brainpylib( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ - - -def event_mv_prob_normal_brainpylib( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ - - -def _event_matvec_prob_homo_abstract( - events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' 
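The removed brainpylib wrappers above derive the compressed connection length from the connection probability as clen = ceil(1 / conn_prob) * 2 - 1. A small worked example of that arithmetic (illustrative only, not part of the patch):

  import numpy as np

  for conn_prob in (0.5, 0.1, 0.01):
    clen = int(np.ceil(1.0 / conn_prob) * 2 - 1)
    print(conn_prob, '->', clen)  # 0.5 -> 3, 0.1 -> 19, 0.01 -> 199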
- assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - out = ShapedArray(dtype=weight.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_homo_cpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_homo' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_homo' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - weight, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_homo_gpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_homo_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1], ) - - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_homo_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_homo_v2' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, weight, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_homo_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, weight, clen, seed = primals - event_dot, weight_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(weight_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is 
ad.Zero - if type(weight_dot) is ad.Zero: - if type(event_dot) is ad.Zero: - raise ValueError - dr = mv_prob_homo_p.bind(event_dot, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - elif type(event_dot) is ad.Zero: - dr = mv_prob_homo_p.bind(events, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - else: - dr = mv_prob_homo_p.bind(event_dot, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, dr - - -def _event_matvec_prob_homo_transpose( - ct, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(weight) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_homo_p.bind(ct[0], - weight, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, weight, clen, seed - - -event_mv_prob_homo_p = Primitive('event_mv_prob_homo') -event_mv_prob_homo_p.multiple_results = True -event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract) -event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation -ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp -ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose -register_general_batching(event_mv_prob_homo_p) - - -def _event_matvec_prob_uniform_abstract( - events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_low_dtype = _get_dtype(w_low) - _w_high_dtype = _get_dtype(w_low) - assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' - assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' - assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' 
- assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if w_low.ndim != 1: - raise ValueError('w_low must be a 1D scalar.') - if w_high.ndim != 1: - raise ValueError('w_high must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen must be a 1D scalar.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - assert w_low.dtype == w_high.dtype - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_low.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_uniform_cpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_uniform' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_uniform' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_low, - w_high, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_uniform_gpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_uniform_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_uniform_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_uniform_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_low, w_high, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_uniform_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_low, w_high, clen, seed = primals - events_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents - r = 
event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, - shape=shape, - outdim_parallel=outdim_parallel, - transpose=transpose) - assert type(w_low_dot) is ad.Zero - assert type(w_high_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_uniform_p.bind(events_dot, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_uniform_transpose( - ct, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_low) is not ad.UndefinedPrimal - assert type(w_high) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_uniform_p.bind(ct[0], - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_low, w_high, clen, seed - - -event_mv_prob_uniform_p = Primitive('event_mv_prob_uniform') -event_mv_prob_uniform_p.multiple_results = True -event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract) -event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation -register_general_batching(event_mv_prob_uniform_p) -ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp -ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose - - -def _event_matvec_prob_normal_abstract( - events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_mu_dtype = _get_dtype(w_mu) - _w_sigma_dtype = _get_dtype(w_sigma) - assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' - assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' 
- assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if w_mu.ndim != 1: - raise ValueError('w_mu should be a 1D scalar.') - if w_sigma.ndim != 1: - raise ValueError('w_sigma should be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen should be a 1D scalar.') - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - assert w_mu.dtype == w_sigma.dtype - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_mu.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _get_types(event_shape): - event_type = event_shape.element_type() - if event_type == jnp.bool_: - event_type = b'_bool' - out_dtype = dtypes.canonicalize_dtype(float) - elif event_type == jnp.float32: - event_type = b'_float' - out_dtype = event_shape.element_type() - elif event_type == jnp.float64: - event_type = b'_double' - out_dtype = event_shape.element_type() - else: - raise TypeError - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - return out_dtype, event_type, type_name - - -def _event_matvec_prob_normal_cpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_normal' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_normal' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_mu, - w_sigma, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_normal_gpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_normal_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_normal_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_normal_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - 
operands=(events, w_mu, w_sigma, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_normal_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_mu, w_sigma, clen, seed = primals - events_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_mu_dot) is ad.Zero - assert type(w_sigma_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_normal_p.bind(events_dot, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_normal_transpose( - ct, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_mu) is not ad.UndefinedPrimal - assert type(w_sigma) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_normal_p.bind(ct[0], - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_mu, w_sigma, clen, seed - - -event_mv_prob_normal_p = Primitive('event_mv_prob_normal') -event_mv_prob_normal_p.multiple_results = True -event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract) -event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation -register_general_batching(event_mv_prob_normal_p) -ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp -ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose - - -### TAICHI ### - -def event_mv_prob_homo_taichi( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. 
- weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ events = as_jax(events) weight = as_jax(weight) if jnp.ndim(weight) < 1: diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/matvec.py similarity index 100% rename from brainpy/_src/math/jitconn/_matvec.py rename to brainpy/_src/math/jitconn/matvec.py diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index d8e086540..6fb8d02ec 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- -from functools import partial import jax import jax.numpy as jnp -from absl.testing import parameterized import pytest +from absl.testing import parameterized import brainpy.math as bm from brainpy._src.dependency_check import import_taichi @@ -12,515 +11,419 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) - shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] - -taichi_mv_prob_homo = bm.jitconn.event_mv_prob_homo -taichi_mv_prob_uniform = bm.jitconn.event_mv_prob_uniform -taichi_mv_prob_normal = bm.jitconn.event_mv_prob_normal +shapes = [(100, 200), (1000, 10)] class Test_event_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - homo_data=[-1., ], - bool_event=[True, False], - seed=[1234], - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): - print(f'_test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') - # indices = bm.as_jax(indices) - # indptr = bm.as_jax(indptr) - # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, - # shape=shape, transpose=transpose) - # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) - 
- if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - bool_event=[True, False], - seed=[1234], + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + homo_data=[-1.], + bool_event=[True, False], + seed=[1234], + ) + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): + print(f'_test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') + # indices = bm.as_jax(indices) + # indptr = bm.as_jax(indptr) + # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, + # shape=shape, transpose=transpose) + # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + bool_event=[True, False], + seed=[1234], + ) + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): + print(f'_test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: bm.jitconn.event_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + transpose=transpose, outdim_parallel=outdim_parallel + )[0] ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): - print(f'_test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, 
seed=seed, - transpose=transpose, outdim_parallel=outdim_parallel - )[0] - ) - r1 = f1(events, weights) - r1 = jax.block_until_ready(r1) - r2 = f1(events, weights) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.5] + r1 = f1(events, weights) + r1 = jax.block_until_ready(r1) + r2 = f1(events, weights) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.5 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: bm.jitconn.event_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), + argnums=0 ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.5 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - - r3 = f1(events, 3.) - r3 = jax.block_until_ready(r3) - - self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) - self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - bool_event=bool_event, - seed=1234, - x64=x64 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.4] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - for bool_event in [True, False] + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + + r2 = f1(events, 2.) 
+ r2 = jax.block_until_ready(r2) + + r3 = f1(events, 3.) + r3 = jax.block_until_ready(r3) + + self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_low=[-1.], + w_high=[1.], + bool_event=[True, False] + ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + bool_event=[True, False], + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap( + lambda e: bm.jitconn.event_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, - bool_event=True, seed=1234, x64=False): - print(f'_test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, transpose=transpose, - 
outdim_parallel=outdim_parallel, prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_uniform_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'bool_event={bool_event}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=1234, x64=False): - print(f'_test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap( - lambda e: taichi_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - ) - - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - testcase_name=f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_high: taichi_mv_prob_uniform( - e, - w_low=0., - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) 
- r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - # print(r1) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu={w_mu}, ' - f'w_sigma={w_sigma}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, ] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - for bool_event in [True, False] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, - bool_event=True, seed=1234, x64=False): - print(f'_test_normal: shape = {shape}, ' - f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' - f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] + + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_high: bm.jitconn.event_mv_prob_uniform( + e, + w_low=0., + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=1234, x64=False): - 
print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + # print(r1) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1, ], + w_mu=[0.], + w_sigma=[0.1], + bool_event=[True, False], + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal: shape = {shape}, ' + f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' + f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose = [True, False], + x64 = [True, False], + outdim_parallel = [True, False], + shape = shapes, + prob = [0.1], + bool_event = [True, False], + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + 
transpose=transpose)) + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose = [True, False], + x64 = [True, False], + outdim_parallel = [True, False], + shape = shapes, + prob = [0.1] + ) + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.jit( + jax.grad( + lambda e, w_sigma: bm.jitconn.event_mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() + ) ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.jit( - jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) 
+ r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 8a0ae444d..67c18124f 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- -from functools import partial import jax import jax.numpy as jnp -from absl.testing import parameterized import pytest +from absl.testing import parameterized import brainpy.math as bm from brainpy._src.dependency_check import import_taichi @@ -12,55 +11,38 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) -shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] - -taichi_mv_prob_homo = bm.jitconn.mv_prob_homo -taichi_mv_prob_uniform = bm.jitconn.mv_prob_uniform -taichi_mv_prob_normal = bm.jitconn.mv_prob_normal +shapes = [(100, 200), (1000, 10)] class Test_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - homo_data=homo_data, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for homo_data in [-1., 1.] - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): - print(f'test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_homo(vector, + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + homo_data=[1.] 
+ ) + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): + print(f'test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, @@ -68,152 +50,118 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=123 outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_homo(vector, + r2 = bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( - event, data, - conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - )[0] - ) - r1 = f1(events, weights) - r2 = f1(events, weights) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: bm.jitconn.mv_prob_homo( + event, data, + conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose + )[0] ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - 
f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( - event, data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - - self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - x64=x64, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] + r1 = f1(events, weights) + r2 = f1(events, weights) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: bm.jitconn.mv_prob_homo( + event, data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum(), + argnums=0 ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): - print(f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64 = {x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_uniform(events, + r1 = f1(events, 1.) + r2 = f1(events, 2.) 
+ + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_low=[-0.1], + w_high=[1.0], + ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): + print(f'test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64 = {x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -222,7 +170,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_uniform(events, + r2 = bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -230,45 +178,35 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e, + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform(e, w_low=0., w_high=1., conn_prob=prob, @@ -277,107 +215,81 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, outdim_parallel=outdim_parallel, transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - 
f'prob={prob}, ' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - f1 = jax.grad( - lambda e, w_low, w_high: taichi_mv_prob_uniform( - e, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() - ) - - r1 = f1(events, 0., 1.) - r2 = f1(events, 0., 2.) - - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=(f'test_normal, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma},' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - seed=1234 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] + r1 = f1(events) + r2 = f1(events) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + f1 = jax.grad( + lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform( + e, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): - print(f'_test_normal: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_normal(events, + + r1 = f1(events, 0., 1.) + r2 = f1(events, 0., 2.) 
+ + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_mu=[0.], + w_sigma=[0.2] + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): + print(f'_test_normal: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_mu = {w_mu}, ' + f'w_sigma = {w_sigma}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -386,7 +298,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_normal(events, + r2 = bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -394,46 +306,36 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_normal_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal(e, w_mu=0., w_sigma=1., conn_prob=prob, @@ -441,66 +343,54 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - print(r1 - r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - 
x64=x64, - testcase_name=f'test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + r1 = f1(events) + r2 = f1(events) + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + print(r1 - r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_sigma: bm.jitconn.mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + r1 = f1(events, 1.) + r2 = f1(events, 2.) 
+ self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 7fac4452d..f9328906e 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -16,6 +16,7 @@ from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call +from brainpy.errors import PackageMissingError from brainpy._src.dependency_check import (import_taichi, import_brainpylib_cpu_ops, import_brainpylib_gpu_ops) @@ -485,10 +486,16 @@ def _taichi_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): def register_taichi_aot_mlir_cpu_translation_rule(primitive, cpu_kernel): + if import_taichi(error_if_not_found=False) is None: + raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule') + rule = partial(_taichi_mlir_cpu_translation_rule, cpu_kernel) mlir.register_lowering(primitive, rule, platform='cpu') def register_taichi_aot_mlir_gpu_translation_rule(primitive, gpu_kernel): + if import_taichi(error_if_not_found=False) is None: + raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule') + rule = partial(_taichi_mlir_gpu_translation_rule, gpu_kernel) mlir.register_lowering(primitive, rule, platform='gpu') diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py index dc093f624..28b80d0f4 100644 --- a/brainpy/_src/math/op_register/tests/test_numba_based.py +++ b/brainpy/_src/math/op_register/tests/test_numba_based.py @@ -1,14 +1,16 @@ -import pytest import jax.core -import brainpy.math as bm +import pytest +import brainpy.math as bm from brainpy._src.dependency_check import import_numba + numba = import_numba(error_if_not_found=False) if numba is None: pytest.skip('no numba', allow_module_level=True) bm.set_platform('cpu') + @numba.njit(fastmath=True) def numba_event_csrmv(weight, indices, vector, outs): outs.fill(0) @@ -33,5 +35,3 @@ def test_event_ELL(): call(1000) call(100) bm.clear_buffer_memory() - - diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 4db38fbcb..199dce983 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -1,10 +1,10 @@ -import pytest import jax import jax.numpy as jnp +import pytest import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi + ti = import_taichi(error_if_not_found=False) if ti is None: pytest.skip('no taichi', allow_module_level=True) @@ -35,6 +35,7 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) + @ti.kernel def event_ell_gpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), @@ -47,12 +48,13 @@ def event_ell_gpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) + prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu) def test_taichi_op_register(): s = 1000 - indices = bm.random.randint(0, s, (s, 1000)) + indices = bm.random.randint(0, s, (s, 100)) vector = bm.random.rand(s) < 0.1 weight = bm.array([1.0]) diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py index 
51c964b29..5b27b2fd5 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py @@ -1,16 +1,15 @@ -import brainpy.math as bm import jax import jax.numpy as jnp -import platform -import pytest + +import brainpy.math as bm +import taichi as ti from brainpy._src.dependency_check import import_taichi ti = import_taichi(error_if_not_found=False) if ti is None: + import pytest pytest.skip('no taichi', allow_module_level=True) -if not platform.platform().startswith('Windows'): - pytest.skip(allow_module_level=True) @ti.func def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: @@ -21,6 +20,7 @@ def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32): out[index] += weight_val + @ti.kernel def event_ell_cpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), @@ -34,11 +34,13 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) + prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) + def test_taichi_clean_cache(): s = 1000 - indices = bm.random.randint(0, s, (s, 1000)) + indices = bm.random.randint(0, s, (s, 100)) vector = bm.random.rand(s) < 0.1 weight = bm.array([1.0]) @@ -55,4 +57,4 @@ def test_taichi_clean_cache(): print('kernels: ', bm.check_kernels_count()) -# test_taichi_clean_cache() \ No newline at end of file +# test_taichi_clean_cache() diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index d53533247..14256cbce 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,8 +1,8 @@ # from ._coo_mv import * # from ._bsr_mv import * -from ._csr_mv import * -from ._utils import * -from ._bsr_mm import * -from ._jax_prim import * +from .csr_mv import * +from .utils import * +from .bsr_mm import * +from .jax_prim import * diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/bsr_mm.py similarity index 100% rename from brainpy/_src/math/sparse/_bsr_mm.py rename to brainpy/_src/math/sparse/bsr_mm.py diff --git a/brainpy/_src/math/sparse/_bsr_mv.py b/brainpy/_src/math/sparse/bsr_mv.py similarity index 99% rename from brainpy/_src/math/sparse/_bsr_mv.py rename to brainpy/_src/math/sparse/bsr_mv.py index a35895bc1..7dc0b683d 100644 --- a/brainpy/_src/math/sparse/_bsr_mv.py +++ b/brainpy/_src/math/sparse/bsr_mv.py @@ -11,7 +11,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy._src.dependency_check import import_brainpylib_gpu_ops from brainpy.errors import GPUOperatorNotFound diff --git a/brainpy/_src/math/sparse/_coo_mv.py b/brainpy/_src/math/sparse/coo_mv.py similarity index 100% rename from brainpy/_src/math/sparse/_coo_mv.py rename to brainpy/_src/math/sparse/coo_mv.py diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/csr_mv.py similarity index 99% rename from brainpy/_src/math/sparse/_csr_mv.py rename to brainpy/_src/math/sparse/csr_mv.py index 42969f435..6eaf6b791 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/csr_mv.py @@ -13,7 +13,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from 
brainpy._src.math.op_register import (register_general_batching, XLACustomOp) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy.errors import PackageMissingError ti = import_taichi(error_if_not_found=False) @@ -108,6 +108,7 @@ def raw_csrmv_taichi( ): if ti is None: raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + out_shape = shape[1] if transpose else shape[0] if data.shape[0] != 1: if bm.get_platform() == 'gpu': diff --git a/brainpy/_src/math/sparse/_jax_prim.py b/brainpy/_src/math/sparse/jax_prim.py similarity index 100% rename from brainpy/_src/math/sparse/_jax_prim.py rename to brainpy/_src/math/sparse/jax_prim.py diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index ec448e658..40bcbb706 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -3,12 +3,11 @@ from functools import partial import jax +import pytest from absl.testing import parameterized -import pytest import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: @@ -25,7 +24,6 @@ def func(*args, **kwargs): return func - def compare_with_nan_tolerance(a, b, tol=1e-8): """ Compare two arrays with tolerance for NaN values. @@ -56,8 +54,6 @@ def compare_with_nan_tolerance(a, b, tol=1e-8): return bm.allclose(a_non_nan, b_non_nan, atol=tol) -taichi_csr_matvec = bm.sparse.csrmv - class Test_csrmv_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): super(Test_csrmv_taichi, self).__init__(*args, **kwargs) @@ -67,8 +63,8 @@ def __init__(self, *args, platform='cpu', **kwargs): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] + shape=[(200, 200), (10, 1000)], + homo_data=[1.] ) def test_homo(self, transpose, shape, homo_data): print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') @@ -87,15 +83,15 @@ def test_homo(self, transpose, shape, homo_data): dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (vector @ dense) if transpose else (dense @ vector) - r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = bm.sparse.csrmv(bm.asarray([homo_data]), indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (100, 1000), (2, 2000)], - v=[-1., 0., 1.] + shape=[(200, 200), (100, 1000)], + v=[1.] 
) def test_homo_vmap(self, transpose, shape, v): print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}') @@ -113,7 +109,7 @@ def test_homo_vmap(self, transpose, shape, v): dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) - f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, + f2 = partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) r1 = jax.vmap(f1)(dense_data) r2 = jax.vmap(f2)(homo_data) @@ -123,8 +119,8 @@ def test_homo_vmap(self, transpose, shape, v): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] + shape=[(200, 200), (10, 1000)], + homo_data=[1.] ) def test_homo_grad(self, transpose, shape, homo_data): print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') @@ -148,8 +144,7 @@ def test_homo_grad(self, transpose, shape, homo_data): ((dense * a) @ vector).sum()), argnums=0) r1 = dense_f1(homo_data) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(bm.sparse.csrmv))(bm.asarray([homo_data]), indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -158,8 +153,8 @@ def test_homo_grad(self, transpose, shape, homo_data): dense_data = dense * homo_data dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) r3 = dense_f2(vector) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r4 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( + bm.asarray([homo_data]), indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) @@ -168,8 +163,8 @@ def test_homo_grad(self, transpose, shape, homo_data): ((dense * a) @ v).sum()), argnums=(0, 1)) r5 = dense_f3(homo_data, vector) - r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r6 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( + bm.asarray([homo_data]), indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) @@ -177,7 +172,7 @@ def test_homo_grad(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (2, 2000)], + shape=[(200, 200), (2, 2000)], ) def test_heter(self, transpose, shape): print(f'test_homo: transpose = {transpose} shape = {shape}') @@ -196,7 +191,7 @@ def test_heter(self, transpose, shape): dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (vector @ dense) if transpose else (dense @ vector) - r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = bm.sparse.csrmv(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(compare_with_nan_tolerance(r1, r2)) @@ -204,7 +199,7 @@ def test_heter(self, transpose, shape): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + shape=[(200, 200), (2, 2000)] ) def test_heter_vmap(self, transpose, shape): rng = bm.random.RandomState(seed=seed) @@ -222,7 
+217,7 @@ def test_heter_vmap(self, transpose, shape): shape=shape))(heter_data) f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) - f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, + f2 = partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) r1 = jax.vmap(f1)(dense_data) r2 = jax.vmap(f2)(heter_data) @@ -230,7 +225,7 @@ def test_heter_vmap(self, transpose, shape): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + shape=[(200, 200), (2, 2000)] ) def test_heter_grad(self, transpose, shape): rng = bm.random.RandomState(seed=seed) @@ -246,11 +241,10 @@ def test_heter_grad(self, transpose, shape): vector = bm.as_jax(vector) # grad 'data' - dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), - argnums=0) - csr_f1 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, - shape=shape, - transpose=transpose).sum(), + dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), argnums=0) + csr_f1 = jax.grad(lambda a: bm.sparse.csrmv(a, indices, indptr, vector, + shape=shape, + transpose=transpose).sum(), argnums=0) r1 = csr_f1(heter_data) r2 = dense_f1(dense_data) @@ -260,12 +254,11 @@ def test_heter_grad(self, transpose, shape): self.assertTrue(bm.allclose(r1, r2)) # grad 'vector' - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), - argnums=0) - csr_f2 = jax.grad(lambda v: taichi_csr_matvec(heter_data, indices, indptr, v, - shape=shape, - transpose=transpose).sum(), - argnums=0) + dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), argnums=0) + csr_f2 = jax.grad(lambda v: bm.sparse.csrmv(heter_data, indices, indptr, v, + shape=shape, + transpose=transpose).sum(), + argnums=0) r3 = dense_f2(vector) r4 = csr_f2(vector) self.assertTrue(bm.allclose(r3, r4)) diff --git a/brainpy/_src/math/sparse/_utils.py b/brainpy/_src/math/sparse/utils.py similarity index 100% rename from brainpy/_src/math/sparse/_utils.py rename to brainpy/_src/math/sparse/utils.py diff --git a/requirements-dev-raw.txt b/requirements-dev-raw.txt deleted file mode 100644 index 99361efa9..000000000 --- a/requirements-dev-raw.txt +++ /dev/null @@ -1,12 +0,0 @@ -numpy -jax -jaxlib -matplotlib -msgpack -tqdm -pathos - - -# test requirements -pytest -absl-py
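
The hunks above migrate these tests from the old taichi_* aliases and parameterized.named_parameters to the public bm.jitconn.* / bm.sparse.* functions and parameterized.product. As a minimal sketch of that pattern (not part of the patch itself; the class and method names below are illustrative, while the APIs and parameter names are the ones used in the hunks above), such a test can look like:

import jax.numpy as jnp
from absl.testing import parameterized

import brainpy.math as bm


class SketchUniformTest(parameterized.TestCase):
  # Illustrative only: exercises bm.jitconn.mv_prob_uniform over a small
  # parameter grid, mirroring the migrated tests above.

  @parameterized.product(
    transpose=[True, False],
    outdim_parallel=[True, False],
    shape=[(100, 200)],
    prob=[0.1],
  )
  def test_uniform_deterministic(self, shape, transpose, outdim_parallel, prob, seed=1234):
    rng = bm.random.RandomState()
    events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))

    # Two calls with identical arguments (including the seed) should agree.
    r1 = bm.jitconn.mv_prob_uniform(events, w_low=0., w_high=1., conn_prob=prob,
                                    shape=shape, seed=seed,
                                    outdim_parallel=outdim_parallel, transpose=transpose)
    r2 = bm.jitconn.mv_prob_uniform(events, w_low=0., w_high=1., conn_prob=prob,
                                    shape=shape, seed=seed,
                                    outdim_parallel=outdim_parallel, transpose=transpose)

    self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
    bm.clear_buffer_memory()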