Skip to content

Commit

Permalink
Merge branch 'fix-ad-support' into braintaichi-op
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Nov 23, 2024
2 parents dc63758 + 2f9952f commit 7e4575c
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ jobs:
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
pip install jax==0.4.30
pip install jaxlib==0.4.30
- name: Test with pytest
run: |
cd brainpy
Expand Down
1 change: 1 addition & 0 deletions brainpy/_src/dnn/tests/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import brainpy.math as bm



class Test_Activation(parameterized.TestCase):

@parameterized.product(
Expand Down
5 changes: 5 additions & 0 deletions brainpy/_src/dnn/tests/test_conv_layers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# -*- coding: utf-8 -*-
import platform

import jax.numpy as jnp
import pytest
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm

if platform.system() == 'Darwin':
pytest.skip('skip Mac OS', allow_module_level=True)


class TestConv(parameterized.TestCase):
def test_Conv2D_img(self):
Expand Down
9 changes: 7 additions & 2 deletions brainpy/_src/math/op_register/ad_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@ def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
r = tuple(rule(t, *primals, **params))
tangents_out.append(r)
assert tree_util.tree_structure(r) == tree
return val_out, functools.reduce(_add_tangents,
try:
return val_out, functools.reduce(_add_tangents,
tangents_out,
tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
tree_util.tree_map(lambda a: ad.Zero.from_primal_value(a), val_out))
except:
return val_out, functools.reduce(_add_tangents,
tangents_out,
tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))


def _add_tangents(xs, ys):
Expand Down

0 comments on commit 7e4575c

Please sign in to comment.