Skip to content

Commit

Permalink
add tests on shear transformation for Gaussian images (#29)
Browse files Browse the repository at this point in the history
* useful for now to track

* general func to  draw gaussian

* typing correction

* test shear transform on images too
  • Loading branch information
ismael-mendoza authored Oct 22, 2024
1 parent 135539e commit f453d13
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 5 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ cython_debug/
# personal
experiments/samples/*.hdf5
**/slurm-*.out
**/bash/*.sh
**/bash/**/*.sh
**/bash/old/*.sh
**/logs/**
**log**.txt
**/jobs_out/**
27 changes: 27 additions & 0 deletions bpd/draw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import jax_galsim as xgalsim

GSPARAMS = xgalsim.GSParams(minimum_fft_size=256, maximum_fft_size=256)


def draw_gaussian(
f: float,
hlr: float,
e1: float,
e2: float,
g1: float,
g2: float,
x: float,
y: float,
pixel_scale: float = 0.2,
slen: int = 53,
psf_hlr: float = 0.7,
):
# x, y arguments in pixels
gal = xgalsim.Gaussian(flux=f, half_light_radius=hlr)
gal = gal.shear(g1=e1, g2=e2)
gal = gal.shear(g1=g1, g2=g2)

psf = xgalsim.Gaussian(flux=1, half_light_radius=psf_hlr)
gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(GSPARAMS)
image = gal_conv.drawImage(nx=slen, ny=slen, scale=pixel_scale, offset=(x, y))
return image.array
6 changes: 3 additions & 3 deletions bpd/prior.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import jax.numpy as jnp
from jax import random
from jax import Array, random


def ellip_mag_prior(e, sigma: float):
Expand Down Expand Up @@ -68,7 +68,7 @@ def scalar_inv_shear_transformation(e: tuple[float, float], g: tuple[float, floa
inv_shear_func2 = lambda e, g: scalar_inv_shear_transformation(e, g)[1]


def shear_transformation(e, g: tuple[float, float]):
def shear_transformation(e: Array, g: tuple[float, float]):
"""Transform elliptiticies by a fixed shear.
The transformation we used is equation 3.4b in Seitz & Schneider (1997).
Expand All @@ -83,7 +83,7 @@ def shear_transformation(e, g: tuple[float, float]):
return jnp.stack([e_prime.real, e_prime.imag], axis=-1)


def inv_shear_transformation(e, g: tuple[float, float]):
def inv_shear_transformation(e: Array, g: tuple[float, float]):
"""Same as above but the inverse."""
e1, e2 = e[..., 0], e[..., 1]
g1, g2 = g
Expand Down
24 changes: 24 additions & 0 deletions scripts/bash/gpu_job_parallel.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash
#SBATCH --account=m1727
#SBATCH -C gpu
#SBATCH -N 1
#SBATCH -t 00:20:00
#SBATCH --ntasks-per-node=4
#SBATCH --mail-type=begin,end,fail
#SBATCH [email protected]

#ref: https://docs.nersc.gov/systems/perlmutter/running-jobs/#single-gpu-tasks-in-parallel
K=1000
TRIM=10
N=1_000
BASE_SEED=61
TAG="gpu1_n10000_test"

for i in $(seq 1 4);
do
SEED="${BASE_SEED}${i}"
CMD="python /global/u2/i/imendoza/BPD/scripts/vect_toy_shear_gpu.py --n-samples-gals ${N} --n-samples-shear 3000 --n-vec 50 --seed ${SEED} --n-seeds 250 --tag ${TAG} --k ${K} --trim ${TRIM} --sigma-e-int 2e-3"
srun --exact -u -n 1 --gpus-per-task 1 -c 1 --mem-per-gpu=20G $CMD &
done

wait
70 changes: 70 additions & 0 deletions tests/test_shear_trans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
import pytest
from jax import jit as jjit
from jax import random

from bpd.draw import draw_gaussian
from bpd.prior import (
inv_shear_transformation,
sample_ellip_prior,
scalar_inv_shear_transformation,
scalar_shear_transformation,
shear_transformation,
)


def test_scalar_inverse():

# scalar version
ellips = (0.0, 0.1, 0.2, -0.1, -0.2)
shears = (0.0, -0.01, 0.01, -0.02, 0.02)
for e1 in ellips:
for e2 in ellips:
for g1 in shears:
for g2 in shears:
e_trans = scalar_shear_transformation((e1, e2), (g1, g2))
e1_new, e2_new = scalar_inv_shear_transformation(e_trans, (g1, g2))

e_array = np.array([e1, e2])
e_new_array = np.array([e1_new, e2_new])
np.testing.assert_allclose(e_new_array, e_array, atol=1e-15)


@pytest.mark.parametrize("seed", [1234, 4567])
def test_transformation(seed):
shears = (0.0, -0.01, 0.01, -0.02, 0.02)

k = random.key(seed)
e_samples = sample_ellip_prior(k, sigma=0.3, n=100)
assert e_samples.shape == (100, 2)

for g1 in shears:
for g2 in shears:
e_trans_samples = shear_transformation(e_samples, (g1, g2))
e_new = inv_shear_transformation(e_trans_samples, (g1, g2))
assert e_new.shape == (100, 2)
np.testing.assert_allclose(e_new, e_samples)


def test_image_shear_commute():
"""Test that the shear operation on galsim commutes with the analytical shear transformation."""
ellips = (0.0, 0.1, 0.2, -0.1, -0.2)
shears = (0.0, -0.01, 0.01, -0.02, 0.02)
f = 1e3
hlr = 0.9
x, y = (1, 1)

draw_jitted = jjit(draw_gaussian)
for e1 in ellips:
for e2 in ellips:
for g1 in shears:
for g2 in shears:
(e1_p, e2_p) = scalar_shear_transformation((e1, e2), (g1, g2))
im1 = draw_jitted(
f=f, hlr=hlr, e1=e1, e2=e2, g1=g1, g2=g2, x=x, y=y
)
im2 = draw_jitted(
f=f, hlr=hlr, e1=e1_p, e2=e2_p, g1=0.0, g2=0.0, x=x, y=y
)

np.testing.assert_allclose(im1, im2, rtol=1e-6, atol=1e-10)

0 comments on commit f453d13

Please sign in to comment.