diff --git a/optax/_src/linesearch_test.py b/optax/_src/linesearch_test.py index ab3a2dfd..ac11a692 100644 --- a/optax/_src/linesearch_test.py +++ b/optax/_src/linesearch_test.py @@ -485,8 +485,8 @@ def test_linesearch(self, problem_name: str, seed: int): with self.subTest('Check against scipy'): stepsize = otu.tree_get(final_state, 'learning_rate') final_value = otu.tree_get(final_state, 'value') - chex.assert_trees_all_close(scipy_res[0], stepsize, rtol=1e-5) - chex.assert_trees_all_close(scipy_res[3], final_value, rtol=1e-5) + chex.assert_trees_all_close(scipy_res[0], stepsize, rtol=1e-4) + chex.assert_trees_all_close(scipy_res[3], final_value, rtol=1e-4) def test_failure_descent_direction(self): """Check failure when updates are not a descent direction.""" diff --git a/optax/perturbations/_make_pert_test.py b/optax/perturbations/_make_pert_test.py index 7928fd37..44e2f792 100644 --- a/optax/perturbations/_make_pert_test.py +++ b/optax/perturbations/_make_pert_test.py @@ -114,7 +114,7 @@ def exact_loss(inputs): chex.assert_trees_all_close(expected_grad, got_grad, atol=2e-2) expected_dict = pert_argmax_fun(self.tree_a_dict_jax, self.rng_jax) got_dict = jtu.tree_map(softmax_fun, self.tree_a_dict_jax) - chex.assert_trees_all_close(expected_dict, got_dict, atol=2e-2) + chex.assert_trees_all_close(expected_dict, got_dict, atol=3e-2) def test_values_on_tree(self): """Test that the perturbations are well applied for functions on trees.