diff --git a/test/test_message_passing_jax.py b/test/test_message_passing_jax.py index e40c637f..b27be336 100644 --- a/test/test_message_passing_jax.py +++ b/test/test_message_passing_jax.py @@ -143,7 +143,7 @@ def test_fixed_point_iteration_factorized_fullyconnected(self): qs_jax_factorized = fpi_jax_factorized(A, obs, prior, factor_lists, num_iter=16) for f, _ in enumerate(qs_jax): - self.assertTrue(np.allclose(qs_jax[f], qs_jax_factorized[f])) + self.assertTrue(np.allclose(qs_jax[f], qs_jax_factorized[f], atol=1e-6)) def test_fixed_point_iteration_factorized_sparsegraph(self): """