diff --git a/README.md b/README.md index e798a9e..e2bb7b3 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,8 @@ pip install --upgrade pip conda create -n bpd python=3.12 conda activate bpd -# Install JAX -pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html || pip install -U "jax[cpu]" +# Install JAX (on GPU) +pip install -U "jax[cuda12]" # Install JAX-Galsim pip install git+https://github.com/GalSim-developers/JAX-GalSim.git diff --git a/bpd/pipelines/image_samples.py b/bpd/pipelines/image_samples.py index 78ad7c7..56f3aae 100644 --- a/bpd/pipelines/image_samples.py +++ b/bpd/pipelines/image_samples.py @@ -9,7 +9,11 @@ from bpd.chains import run_inference_nuts from bpd.draw import draw_gaussian_galsim from bpd.noise import add_noise -from bpd.prior import ellip_mag_prior, sample_ellip_prior, scalar_shear_transformation +from bpd.prior import ( + ellip_prior_e1e2, + sample_ellip_prior, + scalar_shear_transformation, +) def sample_target_galaxy_params_simple( @@ -63,8 +67,8 @@ def logprior( prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x) prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x) - e_mag = jnp.sqrt(params["e1"] ** 2 + params["e2"] ** 2) - prior += jnp.log(ellip_mag_prior(e_mag, sigma=sigma_e)) + e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1) + prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e)) return prior diff --git a/bpd/prior.py b/bpd/prior.py index c0c4873..4b438a6 100644 --- a/bpd/prior.py +++ b/bpd/prior.py @@ -23,9 +23,9 @@ def ellip_mag_prior(e_mag: ArrayLike, sigma: float) -> ArrayLike: return (1 - e_mag**2) ** 2 * e_mag * jnp.exp(-(e_mag**2) / (2 * sigma**2)) / _norm -def ellip_prior_e1e2(e: Array, sigma: float) -> ArrayLike: +def ellip_prior_e1e2(e1e2: Array, sigma: float) -> ArrayLike: """Prior on e1, e2 using Gary's prior for magnitude. Includes Jacobian factor: `|e|`""" - e_mag = norm(e, axis=-1) + e_mag = norm(e1e2, axis=-1) _norm1 = ( -4 * sigma**4 + sigma**2 + 8 * sigma**6 * (1 - jnp.exp(-1 / (2 * sigma**2))) diff --git a/experiments/exp3/figs/contours.pdf b/experiments/exp3/figs/contours.pdf index 3ae83a7..4ccc05f 100644 Binary files a/experiments/exp3/figs/contours.pdf and b/experiments/exp3/figs/contours.pdf differ diff --git a/experiments/exp3/figs/hists.pdf b/experiments/exp3/figs/hists.pdf index 1932cdb..e88f185 100644 Binary files a/experiments/exp3/figs/hists.pdf and b/experiments/exp3/figs/hists.pdf differ diff --git a/experiments/exp3/figs/scatter_shapes.pdf b/experiments/exp3/figs/scatter_shapes.pdf index 2806728..6889f59 100644 Binary files a/experiments/exp3/figs/scatter_shapes.pdf and b/experiments/exp3/figs/scatter_shapes.pdf differ diff --git a/experiments/exp3/figs/traces.pdf b/experiments/exp3/figs/traces.pdf index 93b4840..84dc8cc 100644 Binary files a/experiments/exp3/figs/traces.pdf and b/experiments/exp3/figs/traces.pdf differ diff --git a/experiments/exp3/get_posteriors.sh b/experiments/exp3/get_posteriors.sh index 9795ca2..2728b0c 100755 --- a/experiments/exp3/get_posteriors.sh +++ b/experiments/exp3/get_posteriors.sh @@ -3,5 +3,5 @@ export CUDA_VISIBLE_DEVICES="0" export JAX_ENABLE_X64="True" SEED="43" -# ./get_image_interim_samples_fixed.py $SEED +./get_image_interim_samples_fixed.py $SEED ../../scripts/get_shear_from_interim_samples.py $SEED test_fixed_shear_inference_images_$SEED "e_post_${SEED}.npz" --overwrite