From 6a4b817b57819e5ace0ac3dcaa46fe32a3e91778 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 31 Dec 2024 08:16:20 -0800 Subject: [PATCH] fix exp2 --- .../exp2/run_inference_galaxy_images.py | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/experiments/exp2/run_inference_galaxy_images.py b/experiments/exp2/run_inference_galaxy_images.py index c260a83..b731bdf 100755 --- a/experiments/exp2/run_inference_galaxy_images.py +++ b/experiments/exp2/run_inference_galaxy_images.py @@ -6,8 +6,9 @@ import jax.numpy as jnp import typer -from jax import jit, random, vmap +from jax import Array, jit, random, vmap from jax._src.prng import PRNGKeyArray +from jax.scipy import stats from bpd import DATA_DIR from bpd.chains import run_sampling_nuts, run_warmup_nuts @@ -17,10 +18,35 @@ get_target_images, get_true_params_from_galaxy_params, loglikelihood, - logprior, logtarget, sample_target_galaxy_params_simple, ) +from bpd.prior import ellip_prior_e1e2 + + +def logprior( + params: dict[str, Array], + *, + sigma_e: float, + sigma_x: float = 0.5, # pixels + flux_bds: tuple = (-1.0, 9.0), + hlr_bds: tuple = (0.01, 5.0), +) -> Array: + prior = jnp.array(0.0) + + f1, f2 = flux_bds + prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1) + + h1, h2 = hlr_bds + prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1) + + prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x) + prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x) + + e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1) + prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e)) + + return prior def sample_prior(