-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from mfschubert/jit
Add experimental jit-, vmap-, and jacrev-compatible wrapper
- Loading branch information
Showing
9 changed files
with
727 additions
and
182 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
"""Defines a jax wrapper for autograd-differentiable functions.""" | ||
|
||
import functools | ||
from typing import Any, Callable, List, Tuple, Union | ||
|
||
import autograd | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as onp | ||
from jax import tree_util | ||
|
||
from agjax import utils | ||
|
||
PyTree = Any | ||
|
||
_FORWARD_STAGE = "fwd" | ||
_BACKWARD_STAGE = "bwd" | ||
|
||
|
||
def wrap_for_jax( | ||
fn: Callable[[Any], Any], | ||
result_shape_dtypes: Any, | ||
nondiff_argnums: Union[int, Tuple[int, ...]] = (), | ||
nondiff_outputnums: Union[int, Tuple[int, ...]] = (), | ||
) -> Callable[[Any], Any]: | ||
"""Wraps `fn` so that it can be differentiated by jax. | ||
The wrapped function is suitable for jax transformations such as `grad`, `jit`, | ||
`vmap`, and `jacrev`, which is achieved using `jax.pure_callback`. | ||
Arguments to `fn` must be convertible to jax types, as must all outputs. The | ||
arguments to the wrapped function should be jax types, and the outputs will be | ||
jax types. | ||
Arguments which need not be differentiated with respect to may be specified in | ||
`nondiff_argnums`, while outputs that need not be differentiated may be specified | ||
in `nondiff_outputnums`. | ||
Args: | ||
fn: The autograd-differentiable function. | ||
result_shape_dtypes: A pytree matching the jax-converted output of `fn`. | ||
Specifically, the pytree structure, leaf shapes, and datatypes must match. | ||
nondiff_argnums: The arguments that cannot be differentiated with respect to. | ||
nondiff_outputnums: The outputs that cannot be differentiated. | ||
Returns: | ||
The wrapped function. | ||
""" | ||
_nondiff_argnums, _ = utils.ensure_tuple(nondiff_argnums) | ||
_nondiff_outputnums, _ = utils.ensure_tuple(nondiff_outputnums) | ||
del nondiff_argnums, nondiff_outputnums | ||
|
||
split_args_fn = functools.partial(utils.split, idx=_nondiff_argnums) | ||
merge_args_fn = functools.partial(utils.merge, idx=_nondiff_argnums) | ||
split_outputs_fn = functools.partial(utils.split, idx=_nondiff_outputnums) | ||
merge_outputs_fn = functools.partial(utils.merge, idx=_nondiff_outputnums) | ||
|
||
# Vjp functions created in the "forward stage" of the calculation, and stored in | ||
# the `vjp_fns` list. When the calculation switches from the backward to the | ||
# forward stage, the list of vjp functions is cleared. | ||
vjp_fns: List[jax.tree_util.Partial] = [] | ||
stage = _BACKWARD_STAGE | ||
|
||
@jax.custom_vjp # type: ignore[misc] | ||
def _fn(*args: Any) -> Any: | ||
utils.validate_nondiff_argnums_for_args(_nondiff_argnums, args) | ||
outputs = jax.pure_callback( | ||
lambda *args: utils.to_jax(fn(*utils.to_numpy(args))), | ||
result_shape_dtypes, | ||
*args, | ||
) | ||
utils.validate_nondiff_outputnums_for_outputs(_nondiff_outputnums, outputs) | ||
return outputs | ||
|
||
def _fwd_fn(*args: Any) -> Any: | ||
def make_vjp(*args: Any) -> Any: | ||
nonlocal stage | ||
if stage == _BACKWARD_STAGE: | ||
vjp_fns.clear() | ||
stage = _FORWARD_STAGE | ||
|
||
# Variables updated nonlocally where `fn` is evaluated. | ||
is_tuple_outputs: bool = None # type: ignore[assignment] | ||
nondiff_outputs: Tuple[Any, ...] = None # type: ignore[assignment] | ||
diff_outputs_treedef: jax.tree_util.PyTreeDef = None | ||
|
||
def _tuple_fn(*args: Any) -> onp.ndarray: | ||
nonlocal is_tuple_outputs | ||
nonlocal nondiff_outputs | ||
nonlocal diff_outputs_treedef | ||
|
||
utils.validate_nondiff_argnums_for_args(_nondiff_argnums, args) | ||
outputs = fn(*args) | ||
utils.validate_nondiff_outputnums_for_outputs( | ||
_nondiff_outputnums, outputs | ||
) | ||
|
||
outputs, is_tuple_outputs = utils.ensure_tuple(outputs) | ||
nondiff_outputs, diff_outputs = split_outputs_fn(outputs) | ||
nondiff_outputs = utils.arraybox_to_numpy(nondiff_outputs) | ||
diff_outputs_leaves, diff_outputs_treedef = jax.tree_util.tree_flatten( | ||
diff_outputs | ||
) | ||
return autograd.builtins.tuple(tuple(diff_outputs_leaves)) | ||
|
||
args = utils.to_numpy(args) | ||
diff_argnums = tuple( | ||
i for i in range(len(args)) if i not in _nondiff_argnums | ||
) | ||
tuple_vjp_fn, diff_outputs_leaves = autograd.make_vjp( | ||
_tuple_fn, argnum=diff_argnums | ||
)(*args) | ||
diff_outputs = jax.tree_util.tree_unflatten( | ||
diff_outputs_treedef, diff_outputs_leaves | ||
) | ||
outputs = utils.to_jax(merge_outputs_fn(nondiff_outputs, diff_outputs)) | ||
outputs = outputs if is_tuple_outputs else outputs[0] | ||
|
||
def _vjp_fn(*diff_outputs: Any) -> Any: | ||
diff_outputs_leaves = jax.tree_util.tree_leaves(diff_outputs) | ||
grad = tuple_vjp_fn(utils.to_numpy(diff_outputs_leaves)) | ||
return utils.to_jax(grad) | ||
|
||
key = len(vjp_fns) | ||
vjp_fns.append(tree_util.Partial(_vjp_fn)) | ||
return outputs, jnp.asarray(key) | ||
|
||
outputs, key = jax.pure_callback( | ||
make_vjp, | ||
(result_shape_dtypes, jnp.asarray(0)), | ||
*args, | ||
) | ||
return outputs, (args, key) | ||
|
||
def _bwd_fn(*bwd_args: Any) -> Any: | ||
def _pure_fn(key: jnp.ndarray, tangents: Tuple[Any, ...]) -> Any: | ||
nonlocal stage | ||
stage = _BACKWARD_STAGE | ||
vjp_fn = vjp_fns[int(key)] | ||
return utils.to_jax(vjp_fn(utils.to_numpy(*tangents))) | ||
|
||
(args, key), *tangents = bwd_args | ||
_, diff_args = split_args_fn(args) | ||
result_shape_dtypes = utils.to_jax(diff_args) | ||
grads = jax.pure_callback(_pure_fn, result_shape_dtypes, key, tangents) | ||
return merge_args_fn([None] * len(_nondiff_argnums), grads) | ||
|
||
_fn.defvjp(_fwd_fn, _bwd_fn) | ||
|
||
return _fn # type: ignore[no-any-return] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
"""Defines a utility functions for jax-autograd wrappers.""" | ||
|
||
from typing import Any, Sequence, Tuple | ||
|
||
import autograd.numpy as npa | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as onp | ||
|
||
PyTree = Any | ||
|
||
|
||
class WrappedValue: | ||
"""Wraps a value treated as an auxilliary quantity of a pytree node.""" | ||
|
||
def __init__(self, value: Any) -> None: | ||
self.value = value | ||
|
||
def __repr__(self) -> str: | ||
return f"_WrappedValue({self.value})" | ||
|
||
|
||
jax.tree_util.register_pytree_node( | ||
WrappedValue, | ||
flatten_func=lambda w: ((), (w.value,)), | ||
unflatten_func=lambda v, _: WrappedValue(*v), | ||
) | ||
|
||
|
||
def validate_nondiff_outputnums_for_outputs( | ||
nondiff_outputnums: Sequence[int], | ||
maybe_tuple_outputs: Any, | ||
) -> None: | ||
"""Validates that `nondiff_outputnums` is compatible with a `outputs`.""" | ||
outputs_length = ( | ||
len(maybe_tuple_outputs) if isinstance(maybe_tuple_outputs, tuple) else 1 | ||
) | ||
validate_idx_for_sequence_len(nondiff_outputnums, outputs_length) | ||
if outputs_length <= len(nondiff_outputnums): | ||
raise ValueError( | ||
f"At least one differentiable output is required, but got " | ||
f"`nondiff_outputnums` of {nondiff_outputnums} when `fn` " | ||
f"has {outputs_length} output(s)." | ||
) | ||
|
||
|
||
def validate_nondiff_argnums_for_args( | ||
nondiff_argnums: Sequence[int], | ||
args: Tuple[Any, ...], | ||
) -> None: | ||
"""Validates that `nondiff_argnums` is compatible with a `args`.""" | ||
validate_idx_for_sequence_len(nondiff_argnums, len(args)) | ||
if len(args) <= len(nondiff_argnums): | ||
raise ValueError( | ||
f"At least argument must be differentiated with respect to, but got " | ||
f"`nondiff_argnums` of {nondiff_argnums} when `fn` has {len(args)} " | ||
f"arguments(s)." | ||
) | ||
|
||
|
||
def validate_idx_for_sequence_len(idx: Sequence[int], sequence_length: int) -> None: | ||
"""Validates that `idx` is compatible with a sequence length.""" | ||
if not all(i in range(-sequence_length, sequence_length) for i in idx): | ||
raise ValueError( | ||
f"Found out of bounds values in `idx`, got {idx} when " | ||
f"`sequence_length` is {sequence_length}." | ||
) | ||
positive_idx = [i % sequence_length for i in idx] | ||
if len(positive_idx) != len(set(positive_idx)): | ||
raise ValueError( | ||
f"Found duplicate values in `idx`, got {idx} when " | ||
f"`sequence_length` is {sequence_length}." | ||
) | ||
|
||
|
||
def split( | ||
a: Tuple[Any, ...], | ||
idx: Tuple[int, ...], | ||
) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]: | ||
"""Splits the sequence `a` into two sequences.""" | ||
validate_idx_for_sequence_len(idx, len(a)) | ||
return ( | ||
tuple([a[i] for i in idx]), | ||
tuple([a[i] for i in range(len(a)) if i not in idx]), | ||
) | ||
|
||
|
||
def merge( | ||
a: Sequence[Any], | ||
b: Sequence[Any], | ||
idx: Sequence[int], | ||
) -> Tuple[Any, ...]: | ||
"""Merges the sequences `a` and `b`, undoing a `_split` operation.""" | ||
validate_idx_for_sequence_len(idx, len(a) + len(b)) | ||
positive_idx = [i % (len(a) + len(b)) for i in idx] | ||
iter_a = iter(a) | ||
iter_b = iter(b) | ||
return tuple( | ||
[ | ||
next(iter_a) if i in positive_idx else next(iter_b) | ||
for i in range(len(a) + len(b)) | ||
] | ||
) | ||
|
||
|
||
def to_jax(tree: PyTree) -> PyTree: | ||
"""Converts leaves of a pytree to jax arrays.""" | ||
return jax.tree_util.tree_map(jnp.asarray, tree) | ||
|
||
|
||
def to_numpy(tree: PyTree) -> PyTree: | ||
"""Converts leaves of a pytree to numpy arrays.""" | ||
return jax.tree_util.tree_map(onp.asarray, tree) | ||
|
||
|
||
def arraybox_to_numpy(tree: PyTree) -> PyTree: | ||
"""Converts `ArrayBox` leaves of a pytree to numpy arrays.""" | ||
return jax.tree_util.tree_map( | ||
lambda x: x._value if isinstance(x, npa.numpy_boxes.ArrayBox) else x, | ||
tree, | ||
) | ||
|
||
|
||
def ensure_tuple(xs: Any) -> Tuple[Any, bool]: | ||
"""Returns `(xs, True)` if `xs` is a tuple, and `((xs,), False)` otherwise.""" | ||
is_tuple = isinstance(xs, tuple) | ||
return (xs if is_tuple else (xs,)), is_tuple |
Oops, something went wrong.