Skip to content

Commit

Permalink
fix logprior used in ellipticities for exp3 (#65)
Browse files Browse the repository at this point in the history
* fix logprior here

* better naming

* correction

* fix

* fix readme

* corrected figures
  • Loading branch information
ismael-mendoza authored Dec 30, 2024
1 parent 784d39b commit bf46bb9
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ pip install --upgrade pip
conda create -n bpd python=3.12
conda activate bpd

# Install JAX
pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html || pip install -U "jax[cpu]"
# Install JAX (on GPU)
pip install -U "jax[cuda12]"

# Install JAX-Galsim
pip install git+https://github.com/GalSim-developers/JAX-GalSim.git
Expand Down
10 changes: 7 additions & 3 deletions bpd/pipelines/image_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from bpd.chains import run_inference_nuts
from bpd.draw import draw_gaussian_galsim
from bpd.noise import add_noise
from bpd.prior import ellip_mag_prior, sample_ellip_prior, scalar_shear_transformation
from bpd.prior import (
ellip_prior_e1e2,
sample_ellip_prior,
scalar_shear_transformation,
)


def sample_target_galaxy_params_simple(
Expand Down Expand Up @@ -63,8 +67,8 @@ def logprior(
prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x)

e_mag = jnp.sqrt(params["e1"] ** 2 + params["e2"] ** 2)
prior += jnp.log(ellip_mag_prior(e_mag, sigma=sigma_e))
e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1)
prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e))

return prior

Expand Down
4 changes: 2 additions & 2 deletions bpd/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def ellip_mag_prior(e_mag: ArrayLike, sigma: float) -> ArrayLike:
return (1 - e_mag**2) ** 2 * e_mag * jnp.exp(-(e_mag**2) / (2 * sigma**2)) / _norm


def ellip_prior_e1e2(e: Array, sigma: float) -> ArrayLike:
def ellip_prior_e1e2(e1e2: Array, sigma: float) -> ArrayLike:
"""Prior on e1, e2 using Gary's prior for magnitude. Includes Jacobian factor: `|e|`"""
e_mag = norm(e, axis=-1)
e_mag = norm(e1e2, axis=-1)

_norm1 = (
-4 * sigma**4 + sigma**2 + 8 * sigma**6 * (1 - jnp.exp(-1 / (2 * sigma**2)))
Expand Down
Binary file modified experiments/exp3/figs/contours.pdf
Binary file not shown.
Binary file modified experiments/exp3/figs/hists.pdf
Binary file not shown.
Binary file modified experiments/exp3/figs/scatter_shapes.pdf
Binary file not shown.
Binary file modified experiments/exp3/figs/traces.pdf
Binary file not shown.
2 changes: 1 addition & 1 deletion experiments/exp3/get_posteriors.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ export CUDA_VISIBLE_DEVICES="0"
export JAX_ENABLE_X64="True"
SEED="43"

# ./get_image_interim_samples_fixed.py $SEED
./get_image_interim_samples_fixed.py $SEED
../../scripts/get_shear_from_interim_samples.py $SEED test_fixed_shear_inference_images_$SEED "e_post_${SEED}.npz" --overwrite

0 comments on commit bf46bb9

Please sign in to comment.