Skip to content

Commit

Permalink
Adjust test tolerances for partitionable threefry
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 14, 2025
1 parent b36f6c2 commit 48a9c90
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions optax/_src/linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,8 @@ def test_linesearch(self, problem_name: str, seed: int):
with self.subTest('Check against scipy'):
stepsize = otu.tree_get(final_state, 'learning_rate')
final_value = otu.tree_get(final_state, 'value')
chex.assert_trees_all_close(scipy_res[0], stepsize, rtol=1e-5)
chex.assert_trees_all_close(scipy_res[3], final_value, rtol=1e-5)
chex.assert_trees_all_close(scipy_res[0], stepsize, rtol=1e-4)
chex.assert_trees_all_close(scipy_res[3], final_value, rtol=1e-4)

def test_failure_descent_direction(self):
"""Check failure when updates are not a descent direction."""
Expand Down
2 changes: 1 addition & 1 deletion optax/perturbations/_make_pert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def exact_loss(inputs):
chex.assert_trees_all_close(expected_grad, got_grad, atol=2e-2)
expected_dict = pert_argmax_fun(self.tree_a_dict_jax, self.rng_jax)
got_dict = jtu.tree_map(softmax_fun, self.tree_a_dict_jax)
chex.assert_trees_all_close(expected_dict, got_dict, atol=2e-2)
chex.assert_trees_all_close(expected_dict, got_dict, atol=3e-2)

def test_values_on_tree(self):
"""Test that the perturbations are well applied for functions on trees.
Expand Down

0 comments on commit 48a9c90

Please sign in to comment.