Skip to content

Commit

Permalink
Merge pull request #16 from mfschubert/tests
Browse files Browse the repository at this point in the history
Additional test cases
  • Loading branch information
mfschubert authored Sep 13, 2023
2 parents ee56a7c + f20d251 commit f0fd6e1
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions tests/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,43 @@


TEST_FNS_AND_ARGS = (
( # Basic scalar-valued function.
( # Basic scalar-valued function, real outputs.
lambda x: x**2,
(3.0,),
),
( # Scalar-valued function with two input arguments.
( # Scalar-valued function with two input arguments, real outputs.
lambda x, y: x**2 + y,
(3.0, 4.0),
),
( # Two arguments, scalar output wrapped in a tuple.
( # Two arguments, scalar output wrapped in a tuple, real outputs.
lambda x, y: (x**2 + y,),
(3.0, 4.0),
),
( # Two arguments, two outputs.
( # Two arguments, two outputs, real outputs.
lambda x, y: (x**2 + y, x - y),
(3.0, 4.0),
),
( # Two arguments, two outputs, complex.
( # Two arguments, two outputs, real arguments, complex outputs.
lambda x, y: (x**2 + 1j * y**2, 1j * x**2 - y**2),
(3.0, 4.0),
),
( # Two arguments, two outputs, complex arguments, real outputs.
lambda x, y: (npa.abs(x) ** 2 + npa.abs(y), npa.abs(x + y)),
(3.0 + 1.0j, 4.0 + 0.5j),
),
( # Two arguments, two outputs, complex outputs.
lambda x, y: (x**2 + y, x - y),
(3.0 + 1.0j, 4.0 + 0.5j),
),
( # Two arguments, two outputs, complex.
( # Two arguments, two outputs, complex outputs.
lambda x, y, z: (x**2 + y + z, x - y),
(3.0 + 1.0j, 4.0 + 0.5j, -11.0),
),
( # Returns a pytree.
( # Returns a pytree, complex outputs.
lambda x, y: {"a": x**2 + y, "b": (x - y, y - x)},
(3.0 + 1.0j, 4.0 + 0.5j),
),
( # Arguments and outputs include pytree.
( # Arguments and outputs include pytree, complex outputs.
lambda x, y: {
"a": (x["a0"] + x["a1"]) ** 2 + y,
"b": (x["a0"] - y, y - x["a1"]),
Expand Down Expand Up @@ -112,7 +120,11 @@ def test_wrapped_matches_autograd(self, autograd_fn, args):
nondiff_argnums=(),
nondiff_outputnums=(),
)
onp.testing.assert_array_equal(expected_outputs, wrapped(*args))
for v, ev in zip(
jax.tree_util.tree_leaves(wrapped(*args)),
jax.tree_util.tree_leaves(expected_outputs),
):
onp.testing.assert_allclose(v, ev)

def autograd_scalar_fn(*args):
outputs = autograd_fn(*args)
Expand Down

0 comments on commit f0fd6e1

Please sign in to comment.