Skip to content

Commit

Permalink
Merge pull request #23 from mfschubert/jit
Browse files Browse the repository at this point in the history
Add experimental jit-, vmap-, and jacrev-compatible wrapper
  • Loading branch information
mfschubert authored Feb 20, 2024
2 parents d23d2ce + 6f7c635 commit a6bc34b
Show file tree
Hide file tree
Showing 9 changed files with 727 additions and 182 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ repos:
rev: "v1.0.1"
hooks:
- id: mypy
exclude: ^(docs/|example-plugin/|tests/fixtures)
exclude: ^(docs/|example-plugin/|tests/|fixtures/)
additional_dependencies:
- "pydantic"

Expand Down
Empty file added agjax/experimental/__init__.py
Empty file.
150 changes: 150 additions & 0 deletions agjax/experimental/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""Defines a jax wrapper for autograd-differentiable functions."""

import functools
from typing import Any, Callable, List, Tuple, Union

import autograd
import jax
import jax.numpy as jnp
import numpy as onp
from jax import tree_util

from agjax import utils

PyTree = Any

_FORWARD_STAGE = "fwd"
_BACKWARD_STAGE = "bwd"


def wrap_for_jax(
fn: Callable[[Any], Any],
result_shape_dtypes: Any,
nondiff_argnums: Union[int, Tuple[int, ...]] = (),
nondiff_outputnums: Union[int, Tuple[int, ...]] = (),
) -> Callable[[Any], Any]:
"""Wraps `fn` so that it can be differentiated by jax.
The wrapped function is suitable for jax transformations such as `grad`, `jit`,
`vmap`, and `jacrev`, which is achieved using `jax.pure_callback`.
Arguments to `fn` must be convertible to jax types, as must all outputs. The
arguments to the wrapped function should be jax types, and the outputs will be
jax types.
Arguments which need not be differentiated with respect to may be specified in
`nondiff_argnums`, while outputs that need not be differentiated may be specified
in `nondiff_outputnums`.
Args:
fn: The autograd-differentiable function.
result_shape_dtypes: A pytree matching the jax-converted output of `fn`.
Specifically, the pytree structure, leaf shapes, and datatypes must match.
nondiff_argnums: The arguments that cannot be differentiated with respect to.
nondiff_outputnums: The outputs that cannot be differentiated.
Returns:
The wrapped function.
"""
_nondiff_argnums, _ = utils.ensure_tuple(nondiff_argnums)
_nondiff_outputnums, _ = utils.ensure_tuple(nondiff_outputnums)
del nondiff_argnums, nondiff_outputnums

split_args_fn = functools.partial(utils.split, idx=_nondiff_argnums)
merge_args_fn = functools.partial(utils.merge, idx=_nondiff_argnums)
split_outputs_fn = functools.partial(utils.split, idx=_nondiff_outputnums)
merge_outputs_fn = functools.partial(utils.merge, idx=_nondiff_outputnums)

# Vjp functions created in the "forward stage" of the calculation, and stored in
# the `vjp_fns` list. When the calculation switches from the backward to the
# forward stage, the list of vjp functions is cleared.
vjp_fns: List[jax.tree_util.Partial] = []
stage = _BACKWARD_STAGE

@jax.custom_vjp # type: ignore[misc]
def _fn(*args: Any) -> Any:
utils.validate_nondiff_argnums_for_args(_nondiff_argnums, args)
outputs = jax.pure_callback(
lambda *args: utils.to_jax(fn(*utils.to_numpy(args))),
result_shape_dtypes,
*args,
)
utils.validate_nondiff_outputnums_for_outputs(_nondiff_outputnums, outputs)
return outputs

def _fwd_fn(*args: Any) -> Any:
def make_vjp(*args: Any) -> Any:
nonlocal stage
if stage == _BACKWARD_STAGE:
vjp_fns.clear()
stage = _FORWARD_STAGE

# Variables updated nonlocally where `fn` is evaluated.
is_tuple_outputs: bool = None # type: ignore[assignment]
nondiff_outputs: Tuple[Any, ...] = None # type: ignore[assignment]
diff_outputs_treedef: jax.tree_util.PyTreeDef = None

def _tuple_fn(*args: Any) -> onp.ndarray:
nonlocal is_tuple_outputs
nonlocal nondiff_outputs
nonlocal diff_outputs_treedef

utils.validate_nondiff_argnums_for_args(_nondiff_argnums, args)
outputs = fn(*args)
utils.validate_nondiff_outputnums_for_outputs(
_nondiff_outputnums, outputs
)

outputs, is_tuple_outputs = utils.ensure_tuple(outputs)
nondiff_outputs, diff_outputs = split_outputs_fn(outputs)
nondiff_outputs = utils.arraybox_to_numpy(nondiff_outputs)
diff_outputs_leaves, diff_outputs_treedef = jax.tree_util.tree_flatten(
diff_outputs
)
return autograd.builtins.tuple(tuple(diff_outputs_leaves))

args = utils.to_numpy(args)
diff_argnums = tuple(
i for i in range(len(args)) if i not in _nondiff_argnums
)
tuple_vjp_fn, diff_outputs_leaves = autograd.make_vjp(
_tuple_fn, argnum=diff_argnums
)(*args)
diff_outputs = jax.tree_util.tree_unflatten(
diff_outputs_treedef, diff_outputs_leaves
)
outputs = utils.to_jax(merge_outputs_fn(nondiff_outputs, diff_outputs))
outputs = outputs if is_tuple_outputs else outputs[0]

def _vjp_fn(*diff_outputs: Any) -> Any:
diff_outputs_leaves = jax.tree_util.tree_leaves(diff_outputs)
grad = tuple_vjp_fn(utils.to_numpy(diff_outputs_leaves))
return utils.to_jax(grad)

key = len(vjp_fns)
vjp_fns.append(tree_util.Partial(_vjp_fn))
return outputs, jnp.asarray(key)

outputs, key = jax.pure_callback(
make_vjp,
(result_shape_dtypes, jnp.asarray(0)),
*args,
)
return outputs, (args, key)

def _bwd_fn(*bwd_args: Any) -> Any:
def _pure_fn(key: jnp.ndarray, tangents: Tuple[Any, ...]) -> Any:
nonlocal stage
stage = _BACKWARD_STAGE
vjp_fn = vjp_fns[int(key)]
return utils.to_jax(vjp_fn(utils.to_numpy(*tangents)))

(args, key), *tangents = bwd_args
_, diff_args = split_args_fn(args)
result_shape_dtypes = utils.to_jax(diff_args)
grads = jax.pure_callback(_pure_fn, result_shape_dtypes, key, tangents)
return merge_args_fn([None] * len(_nondiff_argnums), grads)

_fn.defvjp(_fwd_fn, _bwd_fn)

return _fn # type: ignore[no-any-return]
127 changes: 127 additions & 0 deletions agjax/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Defines a utility functions for jax-autograd wrappers."""

from typing import Any, Sequence, Tuple

import autograd.numpy as npa
import jax
import jax.numpy as jnp
import numpy as onp

PyTree = Any


class WrappedValue:
"""Wraps a value treated as an auxilliary quantity of a pytree node."""

def __init__(self, value: Any) -> None:
self.value = value

def __repr__(self) -> str:
return f"_WrappedValue({self.value})"


jax.tree_util.register_pytree_node(
WrappedValue,
flatten_func=lambda w: ((), (w.value,)),
unflatten_func=lambda v, _: WrappedValue(*v),
)


def validate_nondiff_outputnums_for_outputs(
nondiff_outputnums: Sequence[int],
maybe_tuple_outputs: Any,
) -> None:
"""Validates that `nondiff_outputnums` is compatible with a `outputs`."""
outputs_length = (
len(maybe_tuple_outputs) if isinstance(maybe_tuple_outputs, tuple) else 1
)
validate_idx_for_sequence_len(nondiff_outputnums, outputs_length)
if outputs_length <= len(nondiff_outputnums):
raise ValueError(
f"At least one differentiable output is required, but got "
f"`nondiff_outputnums` of {nondiff_outputnums} when `fn` "
f"has {outputs_length} output(s)."
)


def validate_nondiff_argnums_for_args(
nondiff_argnums: Sequence[int],
args: Tuple[Any, ...],
) -> None:
"""Validates that `nondiff_argnums` is compatible with a `args`."""
validate_idx_for_sequence_len(nondiff_argnums, len(args))
if len(args) <= len(nondiff_argnums):
raise ValueError(
f"At least argument must be differentiated with respect to, but got "
f"`nondiff_argnums` of {nondiff_argnums} when `fn` has {len(args)} "
f"arguments(s)."
)


def validate_idx_for_sequence_len(idx: Sequence[int], sequence_length: int) -> None:
"""Validates that `idx` is compatible with a sequence length."""
if not all(i in range(-sequence_length, sequence_length) for i in idx):
raise ValueError(
f"Found out of bounds values in `idx`, got {idx} when "
f"`sequence_length` is {sequence_length}."
)
positive_idx = [i % sequence_length for i in idx]
if len(positive_idx) != len(set(positive_idx)):
raise ValueError(
f"Found duplicate values in `idx`, got {idx} when "
f"`sequence_length` is {sequence_length}."
)


def split(
a: Tuple[Any, ...],
idx: Tuple[int, ...],
) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
"""Splits the sequence `a` into two sequences."""
validate_idx_for_sequence_len(idx, len(a))
return (
tuple([a[i] for i in idx]),
tuple([a[i] for i in range(len(a)) if i not in idx]),
)


def merge(
a: Sequence[Any],
b: Sequence[Any],
idx: Sequence[int],
) -> Tuple[Any, ...]:
"""Merges the sequences `a` and `b`, undoing a `_split` operation."""
validate_idx_for_sequence_len(idx, len(a) + len(b))
positive_idx = [i % (len(a) + len(b)) for i in idx]
iter_a = iter(a)
iter_b = iter(b)
return tuple(
[
next(iter_a) if i in positive_idx else next(iter_b)
for i in range(len(a) + len(b))
]
)


def to_jax(tree: PyTree) -> PyTree:
"""Converts leaves of a pytree to jax arrays."""
return jax.tree_util.tree_map(jnp.asarray, tree)


def to_numpy(tree: PyTree) -> PyTree:
"""Converts leaves of a pytree to numpy arrays."""
return jax.tree_util.tree_map(onp.asarray, tree)


def arraybox_to_numpy(tree: PyTree) -> PyTree:
"""Converts `ArrayBox` leaves of a pytree to numpy arrays."""
return jax.tree_util.tree_map(
lambda x: x._value if isinstance(x, npa.numpy_boxes.ArrayBox) else x,
tree,
)


def ensure_tuple(xs: Any) -> Tuple[Any, bool]:
"""Returns `(xs, True)` if `xs` is a tuple, and `((xs,), False)` otherwise."""
is_tuple = isinstance(xs, tuple)
return (xs if is_tuple else (xs,)), is_tuple
Loading

0 comments on commit a6bc34b

Please sign in to comment.