From 7444a5d39b5558117159294839c2058861f41bed Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 14 Jan 2025 13:16:42 -0800 Subject: [PATCH] fix another test --- tests/test_convergence.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_convergence.py b/tests/test_convergence.py index 12d0191..e9747d4 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -8,10 +8,12 @@ from jax import jit, random, vmap from bpd.chains import run_inference_nuts -from bpd.pipelines.shear_inference import pipeline_shear_inference_ellipticities -from bpd.pipelines.toy_ellips import logtarget as logtarget_toy_ellips -from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples -from bpd.prior import sample_noisy_ellipticities_unclipped +from bpd.pipelines import ( + logtarget_toy_ellips, + pipeline_shear_inference_simple, + pipeline_toy_ellips, +) +from bpd.sample import sample_noisy_ellipticities_unclipped @pytest.mark.parametrize("seed", [1234, 4567]) @@ -84,7 +86,7 @@ def test_toy_shear_convergence(seed): key = random.key(seed) k1, k2 = random.split(key) - e_post, _, _ = pipeline_toy_ellips_samples( + e_post, _, _ = pipeline_toy_ellips( k1, g1=g1, g2=g2, @@ -97,7 +99,7 @@ def test_toy_shear_convergence(seed): # run 4 shear chains over the given e_post _pipeline_shear1 = partial( - pipeline_shear_inference_ellipticities, + pipeline_shear_inference_simple, init_g=true_g, sigma_e=sigma_e, sigma_e_int=sigma_e_int,