diff --git a/bpd/pipelines/image_samples.py b/bpd/pipelines/image_samples.py index f0750f0..0b6ad32 100644 --- a/bpd/pipelines/image_samples.py +++ b/bpd/pipelines/image_samples.py @@ -10,7 +10,34 @@ from bpd.chains import run_inference_nuts from bpd.draw import draw_gaussian_galsim from bpd.noise import add_noise -from bpd.prior import ellip_mag_prior, scalar_shear_transformation +from bpd.prior import ellip_mag_prior, sample_ellip_prior, scalar_shear_transformation + + +def sample_target_galaxy_params_simple( + rng_key: PRNGKeyArray, + *, + shape_noise: float, + g1: float = 0.02, + g2: float = 0.0, +): + """Fix parameters except position and ellipticity, which come from a prior. + + * The position is drawn uniformly within a pixel (dither). + * The ellipticity is drawn from Gary's prior given the shape noise. + + """ + dkey, ekey = random.split(rng_key, 2) + + x, y = random.uniform(dkey, shape=(2,), minval=-0.5, maxval=0.5) + e = sample_ellip_prior(ekey, sigma=shape_noise, n=1) + return { + "e1": e[0, 0], + "e2": e[0, 1], + "x": x, + "y": y, + "g1": g1, + "g2": g2, + } # interim prior @@ -21,23 +48,25 @@ def logprior( sigma_x: float = 0.5, # pixels flux_bds: tuple = (-1.0, 9.0), hlr_bds: tuple = (0.01, 5.0), + all_free: bool = True, ) -> Array: prior = jnp.array(0.0) - f1, f2 = flux_bds - prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1) + if all_free: + f1, f2 = flux_bds + prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1) - h1, h2 = hlr_bds - prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1) + h1, h2 = hlr_bds + prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1) + + # NOTE: hard-coded assumption that galaxy is in center-pixel within odd-size image. + # sigma_x in units of pixels. + prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x) + prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x) e_mag = jnp.sqrt(params["e1"] ** 2 + params["e2"] ** 2) prior += jnp.log(ellip_mag_prior(e_mag, sigma=sigma_e)) - # NOTE: hard-coded assumption that galaxy is in center-pixel within odd-size image. - # sigma_x in units of pixels. - prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x) - prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x) - return prior @@ -47,12 +76,12 @@ def loglikelihood( *, draw_fnc: Callable, background: float, - free_f: bool = True, + free_flux: bool = True, ): # NOTE: draw_fnc should already contain `f` and `hlr` as constant arguments. _draw_params = {**{"g1": 0.0, "g2": 0.0}, **params} # function is more general - if free_f: + if free_flux: _draw_params["f"] = 10 ** _draw_params.pop("lf") model = draw_fnc(**_draw_params) diff --git a/bpd/pipelines/image_samples_fixed_flux.py b/bpd/pipelines/image_samples_fixed_flux.py deleted file mode 100644 index 4b3bf9a..0000000 --- a/bpd/pipelines/image_samples_fixed_flux.py +++ /dev/null @@ -1,49 +0,0 @@ -import jax.numpy as jnp -from jax import Array, random -from jax._src.prng import PRNGKeyArray -from jax.scipy import stats - -from bpd.prior import ellip_mag_prior, sample_ellip_prior - - -def get_target_galaxy_params_simple( - rng_key: PRNGKeyArray, - *, - shape_noise: float, - g1: float = 0.02, - g2: float = 0.0, -): - """Fix parameters except position and ellipticity, which come from a prior. - - * The position is drawn uniformly within a pixel (dither). - * The ellipticity is drawn from Gary's prior given the shape noise. - - """ - dkey, ekey = random.split(rng_key, 2) - - x, y = random.uniform(dkey, shape=(2,), minval=-0.5, maxval=0.5) - e = sample_ellip_prior(ekey, sigma=shape_noise, n=1) - return { - "e1": e[0, 0], - "e2": e[0, 1], - "x": x, - "y": y, - "g1": g1, - "g2": g2, - } - - -def logprior( - params: dict[str, Array], *, sigma_e: float, sigma_x: float = 0.5 -) -> Array: - prior = jnp.array(0.0) - - e_mag = jnp.sqrt(params["e1"] ** 2 + params["e2"] ** 2) - prior += jnp.log(ellip_mag_prior(e_mag, sigma=sigma_e)) - - # NOTE: hard-coded assumption that galaxy is in center-pixel within odd-size image. - # sigma_x in units of pixels. - prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x) - prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x) - - return prior