From 4c5b2c0687f321c0d3d82fabe1342a5ef49a0fd1 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 14 Oct 2024 09:10:46 -0700 Subject: [PATCH] fix tests --- bpd/pipelines/toy_ellips.py | 2 +- tests/test_shear_inference.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/bpd/pipelines/toy_ellips.py b/bpd/pipelines/toy_ellips.py index 2108dfa..dadedd0 100644 --- a/bpd/pipelines/toy_ellips.py +++ b/bpd/pipelines/toy_ellips.py @@ -88,7 +88,7 @@ def pipeline_toy_ellips_samples( do_inference, sigma_e=sigma_e, sigma_m=sigma_m, - interim_posterior=interim_prior, + interim_prior=interim_prior, k=k, ) ) diff --git a/tests/test_shear_inference.py b/tests/test_shear_inference.py index 89a5b9e..bb5f497 100644 --- a/tests/test_shear_inference.py +++ b/tests/test_shear_inference.py @@ -2,6 +2,7 @@ import jax.numpy as jnp import pytest +from jax import random from scripts.get_shear_from_post_ellips import pipeline_shear_inference from scripts.get_toy_ellip_samples import pipeline_toy_ellips_samples @@ -10,13 +11,16 @@ @pytest.mark.parametrize("seed", [1234, 4567]) def test_shear_inference_toy_ellipticities(seed): + key = random.key(seed) + k1, k2 = random.split(key) + g1 = 0.02 g2 = 0.0 sigma_e = 1e-3 sigma_m = 1e-4 e_post = pipeline_toy_ellips_samples( - seed, + k1, g1=g1, g2=g2, sigma_e=sigma_e, @@ -29,7 +33,7 @@ def test_shear_inference_toy_ellipticities(seed): e_post_trimmed = e_post[:, ::10, :] shear_samples = pipeline_shear_inference( - seed, + k2, e_post_trimmed, jnp.array([g1, g2]), sigma_e=sigma_e,