Correct GB prior implementation (#62)
* fix GB prior and remove clipped ellipticity sampling until fixed in issue later

* docstring

* refactor to allow more parameters in prior

* import remove

* update test with changes

* update pipeline

* update function

* use new prior

* update test

* I like dummy variables

* test

* jit is fine

* import jit

* too strict

* slight modification in docstring

* prior should already include sigma_e_int

* propagate change

* propagate other changes to scripts for experiments

* new figures for exp1

* update scripts for exp3

* figures and updates on exp3
ismael-mendoza authored Dec 12, 2024
1 parent cf26964 commit 784d39b
Showing 30 changed files with 484 additions and 265 deletions.
57 changes: 29 additions & 28 deletions bpd/likelihood.py
@@ -1,49 +1,50 @@
 from functools import partial
 from typing import Callable

 import jax.numpy as jnp
 import jax.scipy as jsp
 from jax import Array, grad, vmap
 from jax.numpy.linalg import norm
 from jax.typing import ArrayLike

-from bpd.prior import inv_shear_func1, inv_shear_func2, inv_shear_transformation
+from bpd.prior import (
+    ellip_prior_e1e2,
+    inv_shear_func1,
+    inv_shear_func2,
+    inv_shear_transformation,
+)

 _grad_fnc1 = vmap(vmap(grad(inv_shear_func1), in_axes=(0, None)), in_axes=(0, None))
 _grad_fnc2 = vmap(vmap(grad(inv_shear_func2), in_axes=(0, None)), in_axes=(0, None))
 _inv_shear_trans = vmap(inv_shear_transformation, in_axes=(0, None))


-def shear_loglikelihood(
-    g: Array,
-    sigma_e: float,
-    e_post: Array,
-    *,
-    prior: Callable,
-    interim_prior: Callable,  # fixed
-) -> ArrayLike:
-    # Given by the inference procedure in Schneider et al. 2014
-    # assume single shear g
-    # assume e_obs.shape == (N, K, 2) where N is number of galaxies, K is samples per galaxy
-    # the priors are callables for now on only ellipticities
-    # the interim_prior should have been used when obtaining e_obs from the chain (i.e. for now same sigma)
-    _, _, _ = e_post.shape  # (N, K, 2)
-    _prior = partial(prior, sigma=sigma_e)
-
-    e_post_mag = norm(e_post, axis=-1)
-    denom = interim_prior(e_post_mag)  # (N, K), can ignore angle in prior as uniform
-
-    # for num, use trick
-    # p(w_n' | g, alpha ) = p(w_n' \cross^{-1} g | alpha ) = p(w_n | alpha) * |jac(w_n / w_n')|
-
-    # shape = (N, K, 2)
+def true_ellip_logprior(e_post: Array, g: Array, *, sigma_e: float):
+    """Implementation of GB's true prior on interim posterior samples of ellipticities."""
+
+    # jacobian of inverse shear transformation
     grad1 = _grad_fnc1(e_post, g)
     grad2 = _grad_fnc2(e_post, g)
     absjacdet = jnp.abs(grad1[..., 0] * grad2[..., 1] - grad1[..., 1] * grad2[..., 0])

     # true prior on unsheared ellipticity
     e_post_unsheared = _inv_shear_trans(e_post, g)
-    e_post_unsheared_mag = norm(e_post_unsheared, axis=-1)
-    num = _prior(e_post_unsheared_mag) * absjacdet  # (N, K)
+    prior_val = ellip_prior_e1e2(e_post_unsheared, sigma=sigma_e)
+
+    return jnp.log(prior_val) + jnp.log(absjacdet)
+
+
+def shear_loglikelihood(
+    g: Array,
+    post_params: dict[str, Array],
+    *,
+    logprior: Callable,
+    interim_logprior: Callable,  # fixed
+) -> ArrayLike:
+    """Shear Likelihood implementation of Schneider et al. 2014."""
+    e_post = post_params["e1e2"]
+    _, _, _ = e_post.shape  # (N, K, 2)
+
+    denom = interim_logprior(post_params)
+    num = logprior(post_params, g)

-    ratio = jsp.special.logsumexp(jnp.log(num) - jnp.log(denom), axis=-1)
+    ratio = jsp.special.logsumexp(num - denom, axis=-1)
     return ratio.sum()
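
Editor's note: the refactor splits the Schneider et al. 2014 estimator into two pieces. `true_ellip_logprior` evaluates the true prior at the unsheared ellipticities (change of variables via the inverse-shear Jacobian), and `shear_loglikelihood` importance-weights each galaxy's interim samples: per galaxy, logsumexp_k(num_k - denom_k), summed over galaxies, now done entirely in log space. A minimal runnable sketch of that composition, using toy Gaussian stand-ins for both log-priors (the real code uses `true_ellip_logprior` and `ellip_prior_e1e2` from bpd):

import jax.numpy as jnp
import jax.scipy as jsp
from jax import random

N, K = 5, 100  # toy sizes: 5 galaxies, 100 interim posterior samples each
e_post = 0.1 * random.normal(random.PRNGKey(0), (N, K, 2))

def toy_logprior(post_params, g, sigma_e=0.2):
    # stand-in for true_ellip_logprior: true log-prior at "unsheared" samples
    e = post_params["e1e2"] - g  # crude inverse shear, for illustration only
    return jsp.stats.norm.logpdf(e, scale=sigma_e).sum(-1)  # (N, K)

def toy_interim_logprior(post_params, sigma_e_int=0.3):
    # stand-in for the interim log-prior used when the samples were drawn
    return jsp.stats.norm.logpdf(post_params["e1e2"], scale=sigma_e_int).sum(-1)

def shear_loglikelihood(g, post_params, *, logprior, interim_logprior):
    num = logprior(post_params, g)        # (N, K)
    denom = interim_logprior(post_params)  # (N, K)
    return jsp.special.logsumexp(num - denom, axis=-1).sum()

print(shear_loglikelihood(jnp.array([0.02, 0.0]), {"e1e2": e_post},
                          logprior=toy_logprior,
                          interim_logprior=toy_interim_logprior))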
2 changes: 1 addition & 1 deletion bpd/measure.py
@@ -1,5 +1,5 @@
 import jax.numpy as jnp
-from jaxtyping import ArrayLike
+from jax.typing import ArrayLike


 def get_snr(im: ArrayLike, background: float) -> float:
9 changes: 3 additions & 6 deletions bpd/pipelines/image_samples.py
@@ -2,8 +2,7 @@
 from typing import Callable

 import jax.numpy as jnp
-from jax import Array, random
-from jax import jit as jjit
+from jax import Array, jit, random
 from jax._src.prng import PRNGKeyArray
 from jax.scipy import stats

@@ -158,7 +157,6 @@ def pipeline_interim_samples_one_galaxy(
     initialization_fnc: Callable,
     draw_fnc: Callable,
     logprior: Callable,
-    sigma_e_int: float,
     n_samples: int = 100,
     max_num_doublings: int = 5,
     initial_step_size: float = 1e-3,
@@ -175,10 +173,9 @@ def pipeline_interim_samples_one_galaxy(
     _loglikelihood = partial(
         loglikelihood, draw_fnc=_draw_fnc, background=background, free_flux=free_flux
     )
-    _logprior = partial(logprior, sigma_e=sigma_e_int)

     _logtarget = partial(
-        logtarget, logprior_fnc=_logprior, loglikelihood_fnc=_loglikelihood
+        logtarget, logprior_fnc=logprior, loglikelihood_fnc=_loglikelihood
     )

     _inference_fnc = partial(
@@ -190,7 +187,7 @@ def pipeline_interim_samples_one_galaxy(
         initial_step_size=initial_step_size,
         n_samples=n_samples,
     )
-    _run_inference = jjit(_inference_fnc)
+    _run_inference = jit(_inference_fnc)

     interim_samples = _run_inference(k2, init_position, target_image)
     return interim_samples
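
Editor's note: since `sigma_e_int` is no longer a pipeline argument here, callers now bind the interim prior scale into `logprior` themselves before passing it in. A hypothetical caller-side sketch (`my_image_logprior` is a stand-in name, not a bpd function):

from functools import partial

def my_image_logprior(params, *, sigma_e):  # hypothetical stand-in prior
    ...

bound_logprior = partial(my_image_logprior, sigma_e=0.3)  # sigma_e_int bound here
# pipeline_interim_samples_one_galaxy(..., logprior=bound_logprior, ...)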
48 changes: 30 additions & 18 deletions bpd/pipelines/shear_inference.py
@@ -1,57 +1,69 @@
 from functools import partial
 from typing import Callable

-from jax import Array
-from jax import jit as jjit
+import jax.numpy as jnp
+from jax import Array, jit
 from jax._src.prng import PRNGKeyArray
 from jax.scipy import stats

 from bpd.chains import run_inference_nuts
-from bpd.likelihood import shear_loglikelihood
-from bpd.prior import ellip_mag_prior
+from bpd.likelihood import shear_loglikelihood, true_ellip_logprior
+from bpd.prior import ellip_prior_e1e2


-def logtarget_density(g: Array, *, data: Array, loglikelihood: Callable):
-    loglike = loglikelihood(g, e_post=data)
-    logprior = stats.uniform.logpdf(g, -0.1, 0.2).sum()
+def logtarget_density(
+    g: Array, *, data: Array, loglikelihood: Callable, sigma_g: float = 0.01
+):
+    loglike = loglikelihood(g, post_params=data)
+    logprior = stats.norm.logpdf(g, loc=0.0, scale=sigma_g).sum()
     return logprior + loglike


-def pipeline_shear_inference(
+def _logprior(post_params: dict[str, Array], g: Array, *, sigma_e: float):
+    e_post = post_params["e1e2"]
+    return true_ellip_logprior(e_post, g, sigma_e=sigma_e)
+
+
+def _interim_logprior(post_params: dict[str, Array], sigma_e_int: float):
+    e_post = post_params["e1e2"]
+    return jnp.log(ellip_prior_e1e2(e_post, sigma=sigma_e_int))
+
+
+def pipeline_shear_inference_ellipticities(
     rng_key: PRNGKeyArray,
     e_post: Array,
+    init_g: Array,
     *,
-    true_g: Array,
     sigma_e: float,
     sigma_e_int: float,
     n_samples: int,
     initial_step_size: float,
+    sigma_g: float = 0.01,
     n_warmup_steps: int = 500,
     max_num_doublings: int = 2,
 ):
-    interim_prior = partial(ellip_mag_prior, sigma=sigma_e_int)
-
     # NOTE: jit must be applied without `e_post` in partial!
-    _loglikelihood = jjit(
+    _loglikelihood = jit(
         partial(
             shear_loglikelihood,
-            sigma_e=sigma_e,
-            prior=ellip_mag_prior,
-            interim_prior=interim_prior,
+            logprior=partial(_logprior, sigma_e=sigma_e),
+            interim_logprior=partial(_interim_logprior, sigma_e_int=sigma_e_int),
         )
     )
-    _logtarget = partial(logtarget_density, loglikelihood=_loglikelihood)
+    _logtarget = partial(
+        logtarget_density, loglikelihood=_loglikelihood, sigma_g=sigma_g
+    )

     _do_inference = partial(
         run_inference_nuts,
-        data=e_post,
+        data={"e1e2": e_post},
         logtarget=_logtarget,
         n_samples=n_samples,
         n_warmup_steps=n_warmup_steps,
         max_num_doublings=max_num_doublings,
         initial_step_size=initial_step_size,
     )

-    g_samples = _do_inference(rng_key, true_g)
+    g_samples = _do_inference(rng_key, init_g)

     return g_samples
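
Editor's note: for reference, a hypothetical call of the renamed pipeline under the new signature — the chain now starts at an explicit `init_g` rather than at `true_g`, and the width `sigma_g` of the new Gaussian shear prior is exposed. Array sizes and scales below are illustrative only:

import jax.numpy as jnp
from jax import random

from bpd.pipelines.shear_inference import pipeline_shear_inference_ellipticities

rng_key = random.PRNGKey(42)
e_post = jnp.zeros((1000, 100, 2))  # interim ellipticity samples, (N, K, 2)

g_samples = pipeline_shear_inference_ellipticities(
    rng_key,
    e_post,
    jnp.array([0.0, 0.0]),  # init_g: chain starts here instead of at true_g
    sigma_e=0.2,            # true ellipticity prior scatter
    sigma_e_int=0.3,        # interim prior scatter used upstream
    n_samples=1000,
    initial_step_size=1e-3,
    sigma_g=0.01,           # Gaussian shear prior replaces the old uniform
)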
26 changes: 11 additions & 15 deletions bpd/pipelines/toy_ellips.py
@@ -1,30 +1,28 @@
 from functools import partial
 from typing import Callable

 import jax.numpy as jnp
 import jax.scipy as jsp
-from jax import Array, random, vmap
-from jax import jit as jjit
+from jax import Array, jit, random, vmap
 from jax._src.prng import PRNGKeyArray
-from jax.numpy.linalg import norm

 from bpd.chains import run_inference_nuts
-from bpd.prior import ellip_mag_prior, sample_synthetic_sheared_ellips_unclipped
+from bpd.prior import (
+    ellip_prior_e1e2,
+    sample_noisy_ellipticities_unclipped,
+)


 def logtarget(
     e_sheared: Array,
     *,
     data: Array,  # renamed from `e_obs` for comptability with `do_inference_nuts`
     sigma_m: float,
-    interim_prior: Callable,
+    sigma_e_int: float,
 ):
     e_obs = data
     assert e_sheared.shape == (2,) and e_obs.shape == (2,)

-    # ignore angle prior assumed uniform
-    # prior enforces magnitude < 1.0 for posterior samples
-    prior = jnp.log(interim_prior(norm(e_sheared)))
+    prior = jnp.log(ellip_prior_e1e2(e_sheared, sigma=sigma_e_int))
     likelihood = jnp.sum(jsp.stats.norm.logpdf(e_obs, loc=e_sheared, scale=sigma_m))
     return prior + likelihood

@@ -46,16 +44,14 @@ def pipeline_toy_ellips_samples(

     true_g = jnp.array([g1, g2])

-    e_obs, e_sheared, _ = sample_synthetic_sheared_ellips_unclipped(
-        k1, true_g, n=n_gals, sigma_m=sigma_m, sigma_e=sigma_e
+    e_obs, e_sheared, _ = sample_noisy_ellipticities_unclipped(
+        k1, g=true_g, sigma_m=sigma_m, sigma_e=sigma_e, n=n_gals
     )

-    interim_prior = partial(ellip_mag_prior, sigma=sigma_e_int)
-
-    _logtarget = partial(logtarget, sigma_m=sigma_m, interim_prior=interim_prior)
+    _logtarget = partial(logtarget, sigma_m=sigma_m, sigma_e_int=sigma_e_int)

     keys2 = random.split(k2, n_gals)
-    _do_inference_jitted = jjit(
+    _do_inference_jitted = jit(
         partial(
             run_inference_nuts,
             logtarget=_logtarget,
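
Editor's note: a self-contained sketch of the new toy `logtarget` — the interim prior is now evaluated directly on the (e1, e2) pair with scale `sigma_e_int`, plus an isotropic Gaussian measurement likelihood. `ellip_prior_e1e2` below is a toy Gaussian stand-in; bpd's actual prior additionally enforces |e| < 1:

import jax.numpy as jnp
import jax.scipy as jsp

def ellip_prior_e1e2(e, sigma):  # stand-in; not bpd's actual prior
    return jnp.exp(jsp.stats.norm.logpdf(e, scale=sigma).sum())

def logtarget(e_sheared, *, data, sigma_m, sigma_e_int):
    e_obs = data
    prior = jnp.log(ellip_prior_e1e2(e_sheared, sigma=sigma_e_int))
    likelihood = jnp.sum(jsp.stats.norm.logpdf(e_obs, loc=e_sheared, scale=sigma_m))
    return prior + likelihood

print(logtarget(jnp.array([0.1, -0.05]),
                data=jnp.array([0.12, -0.03]), sigma_m=1e-3, sigma_e_int=0.3))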
