From e903e9a8a5f307eff0352908fc50785a477ebf12 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 22 Oct 2024 13:23:20 -0700 Subject: [PATCH] test shear transform on images too --- tests/test_shear_trans.py | 70 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/test_shear_trans.py diff --git a/tests/test_shear_trans.py b/tests/test_shear_trans.py new file mode 100644 index 0000000..a10b020 --- /dev/null +++ b/tests/test_shear_trans.py @@ -0,0 +1,70 @@ +import numpy as np +import pytest +from jax import jit as jjit +from jax import random + +from bpd.draw import draw_gaussian +from bpd.prior import ( + inv_shear_transformation, + sample_ellip_prior, + scalar_inv_shear_transformation, + scalar_shear_transformation, + shear_transformation, +) + + +def test_scalar_inverse(): + + # scalar version + ellips = (0.0, 0.1, 0.2, -0.1, -0.2) + shears = (0.0, -0.01, 0.01, -0.02, 0.02) + for e1 in ellips: + for e2 in ellips: + for g1 in shears: + for g2 in shears: + e_trans = scalar_shear_transformation((e1, e2), (g1, g2)) + e1_new, e2_new = scalar_inv_shear_transformation(e_trans, (g1, g2)) + + e_array = np.array([e1, e2]) + e_new_array = np.array([e1_new, e2_new]) + np.testing.assert_allclose(e_new_array, e_array, atol=1e-15) + + +@pytest.mark.parametrize("seed", [1234, 4567]) +def test_transformation(seed): + shears = (0.0, -0.01, 0.01, -0.02, 0.02) + + k = random.key(seed) + e_samples = sample_ellip_prior(k, sigma=0.3, n=100) + assert e_samples.shape == (100, 2) + + for g1 in shears: + for g2 in shears: + e_trans_samples = shear_transformation(e_samples, (g1, g2)) + e_new = inv_shear_transformation(e_trans_samples, (g1, g2)) + assert e_new.shape == (100, 2) + np.testing.assert_allclose(e_new, e_samples) + + +def test_image_shear_commute(): + """Test that the shear operation on galsim commutes with the analytical shear transformation.""" + ellips = (0.0, 0.1, 0.2, -0.1, -0.2) + shears = (0.0, -0.01, 0.01, -0.02, 0.02) + f = 1e3 + hlr = 0.9 + x, y = (1, 1) + + draw_jitted = jjit(draw_gaussian) + for e1 in ellips: + for e2 in ellips: + for g1 in shears: + for g2 in shears: + (e1_p, e2_p) = scalar_shear_transformation((e1, e2), (g1, g2)) + im1 = draw_jitted( + f=f, hlr=hlr, e1=e1, e2=e2, g1=g1, g2=g2, x=x, y=y + ) + im2 = draw_jitted( + f=f, hlr=hlr, e1=e1_p, e2=e2_p, g1=0.0, g2=0.0, x=x, y=y + ) + + np.testing.assert_allclose(im1, im2, rtol=1e-6, atol=1e-10)