diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4dc23b35..f5992766 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -79,6 +79,8 @@ jobs: if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install + pip install jax==0.4.30 + pip install jaxlib==0.4.30 - name: Test with pytest run: | cd brainpy diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index 7a0fa57a..51470641 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -4,6 +4,7 @@ import brainpy.math as bm + class Test_Activation(parameterized.TestCase): @parameterized.product( diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 05f52362..af38a355 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -1,12 +1,17 @@ # -*- coding: utf-8 -*- +import platform import jax.numpy as jnp +import pytest from absl.testing import absltest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm +if platform.system() == 'Darwin': + pytest.skip('skip Mac OS', allow_module_level=True) + class TestConv(parameterized.TestCase): def test_Conv2D_img(self): diff --git a/brainpy/_src/math/op_register/ad_support.py b/brainpy/_src/math/op_register/ad_support.py index 342093ea..54a3c9be 100644 --- a/brainpy/_src/math/op_register/ad_support.py +++ b/brainpy/_src/math/op_register/ad_support.py @@ -41,9 +41,14 @@ def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params): r = tuple(rule(t, *primals, **params)) tangents_out.append(r) assert tree_util.tree_structure(r) == tree - return val_out, functools.reduce(_add_tangents, + try: + return val_out, functools.reduce(_add_tangents, tangents_out, - tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out)) + tree_util.tree_map(lambda a: ad.Zero.from_primal_value(a), val_out)) + except: + return val_out, functools.reduce(_add_tangents, + tangents_out, + tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out)) def _add_tangents(xs, ys):