Skip to content

Commit

Permalink
Adjust test tolerances for partitionable threefry
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 15, 2025
1 parent b36f6c2 commit 4e577e5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions optax/_src/linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def test_linesearch_with_jax_variants(self):
value = otu.tree_get(state, 'value')
self.assertFalse(jnp.isinf(value))

@absltest.skip('TODO(rdyro): need to match scipy linesearch algorithm')
@parameterized.product(
problem_name=[
'polynomial',
Expand Down
4 changes: 2 additions & 2 deletions optax/perturbations/_make_pert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def exact_loss(inputs):
expect_hessian = jax.hessian(pert_loss)(self.array_small_jax, self.rng_jax)
got_hessian = jax.hessian(exact_loss)(self.array_small_jax)
chex.assert_trees_all_equal_shapes(expect_hessian, got_hessian)
chex.assert_trees_all_close(expected_grad, got_grad, atol=2e-2)
chex.assert_trees_all_close(expected_grad, got_grad, atol=6e-2)
expected_dict = pert_argmax_fun(self.tree_a_dict_jax, self.rng_jax)
got_dict = jtu.tree_map(softmax_fun, self.tree_a_dict_jax)
chex.assert_trees_all_close(expected_dict, got_dict, atol=2e-2)
chex.assert_trees_all_close(expected_dict, got_dict, atol=6e-2)

def test_values_on_tree(self):
"""Test that the perturbations are well applied for functions on trees.
Expand Down

0 comments on commit 4e577e5

Please sign in to comment.