Skip to content

Commit

Permalink
[test] remove test skip on macos (#597)
Browse files Browse the repository at this point in the history
* [test] remove test skip on macos, since brainpylib supports taichi interface on macos

* update

* updates
  • Loading branch information
chaoming0625 authored Jan 20, 2024
1 parent c2f2db9 commit bc5aa72
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 491 deletions.
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/synapses/delay_couplings.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self.output_var = var_to_output

# Connection matrix
self.conn_mat = bm.asarray(conn_mat)
self.conn_mat = conn_mat
if self.conn_mat.shape != required_shape:
raise ValueError(f'we expect the structural connection matrix has the shape of '
f'(pre.num, post.num), i.e., {required_shape}, '
Expand Down
210 changes: 0 additions & 210 deletions brainpy/_src/math/event/tests/test_event_csrmv_taichi.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,14 @@
# -*- coding: utf-8 -*-


import sys
from functools import partial

import jax
import pytest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm

# pytestmark = pytest.mark.skip(reason="Skipped due to pytest limitations, manual execution required for testing.")

is_manual_test = False
if sys.platform.startswith('darwin') and not is_manual_test:
pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)

# bm.set_platform('cpu')

seed = 1234


Expand All @@ -38,206 +28,6 @@ def func(*args, **kwargs):
return func


# ### MANUAL TESTS ###

# transposes = [True, False]
# shapes = [(100, 200),
# (200, 200),
# (200, 100),
# (10, 1000),
# # (2, 10000),
# # (1000, 10),
# # (10000, 2)
# ]
# homo_datas = [-1., 0., 1.]

# def test_homo(shape, transpose, homo_data):
# print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
# rng = bm.random.RandomState()
# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
# events = rng.random(shape[0] if transpose else shape[1]) < 0.1
# heter_data = bm.ones(indices.shape) * homo_data

# r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
# r2 = bm.event.csrmv_taichi(homo_data, indices, indptr, events, shape=shape, transpose=transpose)

# assert (bm.allclose(r1, r2[0]))

# bm.clear_buffer_memory()


# def test_homo_vmap(shape, transpose, homo_data):
# print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')

# rng = bm.random.RandomState()
# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')

# # vmap 'data'
# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
# f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events,
# shape=shape, transpose=transpose))
# f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events,
# shape=shape, transpose=transpose))
# vmap_data = bm.as_jax([homo_data] * 10)
# assert(bm.allclose(f1(vmap_data), f2(vmap_data)[0]))

# # vmap 'events'
# f3 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr,
# shape=shape, transpose=transpose))
# f4 = jax.vmap(partial(bm.event.csrmv_taichi, homo_data, indices, indptr,
# shape=shape, transpose=transpose))
# vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
# assert(bm.allclose(f3(vmap_data), f4(vmap_data)[0]))

# # vmap 'data' and 'events'
# f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose))
# f6 = jax.vmap(lambda dd, ee: bm.event.csrmv_taichi(dd, indices, indptr, ee, shape=shape, transpose=transpose))

# vmap_data1 = bm.as_jax([homo_data] * 10)
# vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
# assert(bm.allclose(f5(vmap_data1, vmap_data2),
# f6(vmap_data1, vmap_data2)[0]))

# bm.clear_buffer_memory()


# def test_homo_grad(shape, transpose, homo_data):
# print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')

# rng = bm.random.RandomState()
# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
# indices = bm.as_jax(indices)
# indptr = bm.as_jax(indptr)
# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
# dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape)

# # grad 'data'
# r1 = jax.grad(sum_op(bm.event.csrmv))(
# homo_data, indices, indptr, events, shape=shape, transpose=transpose)
# r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))(
# homo_data, indices, indptr, events, shape=shape, transpose=transpose)
# assert(bm.allclose(r1, r2))

# # grad 'events'
# r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(
# homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
# r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)(
# homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
# assert(bm.allclose(r3, r4))

# bm.clear_buffer_memory()


# def test_heter(shape, transpose):
# print(f'test_heter: shape = {shape}, transpose = {transpose}')
# rng = bm.random.RandomState()
# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
# indices = bm.as_jax(indices)
# indptr = bm.as_jax(indptr)
# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
# heter_data = bm.as_jax(rng.random(indices.shape))

# r1 = bm.event.csrmv(heter_data, indices, indptr, events,
# shape=shape, transpose=transpose)
# r2 = bm.event.csrmv_taichi(heter_data, indices, indptr, events,
# shape=shape, transpose=transpose)

# assert(bm.allclose(r1, r2[0]))

# bm.clear_buffer_memory()


# def test_heter_vmap(shape, transpose):
# print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}')

# rng = bm.random.RandomState()
# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
# indices = bm.as_jax(indices)
# indptr = bm.as_jax(indptr)

# # vmap 'data'
# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
# f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events,
# shape=shape, transpose=transpose))
# f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events,
# shape=shape, transpose=transpose))
# vmap_data = bm.as_jax(rng.random((10, indices.shape[0])))
# assert(bm.allclose(f1(vmap_data), f2(vmap_data)[0]))

# # vmap 'events'
# data = bm.as_jax(rng.random(indices.shape))
# f3 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr,
# shape=shape, transpose=transpose))
# f4 = jax.vmap(partial(bm.event.csrmv_taichi, data, indices, indptr,
# shape=shape, transpose=transpose))
# vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
# assert(bm.allclose(f3(vmap_data), f4(vmap_data)[0]))

# # vmap 'data' and 'events'
# f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee,
# shape=shape, transpose=transpose))
# f6 = jax.vmap(lambda dd, ee: bm.event.csrmv_taichi(dd, indices, indptr, ee,
# shape=shape, transpose=transpose))
# vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0])))
# vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
# assert(bm.allclose(f5(vmap_data1, vmap_data2),
# f6(vmap_data1, vmap_data2)[0]))

# bm.clear_buffer_memory()


# def test_heter_grad(shape, transpose):
# print(f'test_heter_grad: shape = {shape}, transpose = {transpose}')

# rng = bm.random.RandomState()
# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
# indices = bm.as_jax(indices)
# indptr = bm.as_jax(indptr)
# events = rng.random(shape[0] if transpose else shape[1]) < 0.1
# events = bm.as_jax(events)
# dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape)

# # grad 'data'
# data = bm.as_jax(rng.random(indices.shape))
# r1 = jax.grad(sum_op(bm.event.csrmv))(
# data, indices, indptr, events, shape=shape, transpose=transpose)
# r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))(
# data, indices, indptr, events, shape=shape, transpose=transpose)
# assert(bm.allclose(r1, r2))

# # grad 'events'
# r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(
# data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
# r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)(
# data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
# assert(bm.allclose(r3, r4))

# r5 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))(
# data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
# r6 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=(0, 3))(
# data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
# assert(bm.allclose(r5[0], r6[0]))
# assert(bm.allclose(r5[1], r6[1]))

# bm.clear_buffer_memory()

# def test_all():
# for transpose in transposes:
# for shape in shapes:
# for homo_data in homo_datas:
# test_homo(shape, transpose, homo_data)
# test_homo_vmap(shape, transpose, homo_data)
# test_homo_grad(shape, transpose, homo_data)

# for transpose in transposes:
# for shape in shapes:
# test_heter(shape, transpose)
# test_heter_vmap(shape, transpose)
# test_heter_grad(shape, transpose)
# test_all()


### PYTEST
class Test_event_csr_matvec_taichi(parameterized.TestCase):
def __init__(self, *args, platform='cpu', **kwargs):
super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs)
Expand Down
6 changes: 0 additions & 6 deletions brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
# -*- coding: utf-8 -*-

import sys

import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized

import brainpy.math as bm

is_manual_test = False
if sys.platform.startswith('darwin') and not is_manual_test:
pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)

shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)]
shapes = [(100, 200), (2, 1000), (1000, 2)]

Expand Down
Loading

0 comments on commit bc5aa72

Please sign in to comment.