Skip to content

Commit

Permalink
fix ad compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Nov 13, 2024
1 parent 2a5adea commit e539c21
Showing 1 changed file with 36 additions and 30 deletions.
66 changes: 36 additions & 30 deletions brainpy/_src/math/op_register/ad_support.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,57 @@
import functools
from functools import partial

import jax
from jax import tree_util
from jax.core import Primitive
from jax.interpreters import ad

__all__ = [
'defjvp',
'defjvp',
]


def defjvp(primitive, *jvp_rules):
"""Define JVP rules for any JAX primitive.
"""Define JVP rules for any JAX primitive.
This function is similar to ``jax.interpreters.ad.defjvp``.
However, the JAX one only supports primitive with ``multiple_results=False``.
``brainpy.math.defjvp`` enables to define the independent JVP rule for
each input parameter no matter ``multiple_results=False/True``.
This function is similar to ``jax.interpreters.ad.defjvp``.
However, the JAX one only supports primitive with ``multiple_results=False``.
``brainpy.math.defjvp`` enables to define the independent JVP rule for
each input parameter no matter ``multiple_results=False/True``.
For examples, please see ``test_ad_support.py``.
For examples, please see ``test_ad_support.py``.
Args:
primitive: Primitive, XLACustomOp.
*jvp_rules: The JVP translation rule for each primal.
"""
assert isinstance(primitive, Primitive)
if primitive.multiple_results:
ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
else:
ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
Args:
primitive: Primitive, XLACustomOp.
*jvp_rules: The JVP translation rule for each primal.
"""
assert isinstance(primitive, Primitive)
if primitive.multiple_results:
ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
else:
ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)


def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
assert primitive.multiple_results
val_out = tuple(primitive.bind(*primals, **params))
tree = tree_util.tree_structure(val_out)
tangents_out = []
for rule, t in zip(jvp_rules, tangents):
if rule is not None and type(t) is not ad.Zero:
r = tuple(rule(t, *primals, **params))
tangents_out.append(r)
assert tree_util.tree_structure(r) == tree
return val_out, functools.reduce(_add_tangents,
tangents_out,
tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
assert primitive.multiple_results
val_out = tuple(primitive.bind(*primals, **params))
tree = tree_util.tree_structure(val_out)
tangents_out = []
for rule, t in zip(jvp_rules, tangents):
if rule is not None and type(t) is not ad.Zero:
r = tuple(rule(t, *primals, **params))
tangents_out.append(r)
assert tree_util.tree_structure(r) == tree
return val_out, functools.reduce(
_add_tangents,
tangents_out,
tree_util.tree_map(
# compatible with JAX 0.4.34
lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a),
val_out
)
)


def _add_tangents(xs, ys):
return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))

return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))

0 comments on commit e539c21

Please sign in to comment.