Skip to content

Commit

Permalink
fix exp2
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Dec 31, 2024
1 parent ed0f7cc commit 6a4b817
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions experiments/exp2/run_inference_galaxy_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import jax.numpy as jnp
import typer
from jax import jit, random, vmap
from jax import Array, jit, random, vmap
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

from bpd import DATA_DIR
from bpd.chains import run_sampling_nuts, run_warmup_nuts
Expand All @@ -17,10 +18,35 @@
get_target_images,
get_true_params_from_galaxy_params,
loglikelihood,
logprior,
logtarget,
sample_target_galaxy_params_simple,
)
from bpd.prior import ellip_prior_e1e2


def logprior(
params: dict[str, Array],
*,
sigma_e: float,
sigma_x: float = 0.5, # pixels
flux_bds: tuple = (-1.0, 9.0),
hlr_bds: tuple = (0.01, 5.0),
) -> Array:
prior = jnp.array(0.0)

f1, f2 = flux_bds
prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1)

h1, h2 = hlr_bds
prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1)

prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x)

e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1)
prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e))

return prior


def sample_prior(
Expand Down

0 comments on commit 6a4b817

Please sign in to comment.