-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Correct GB prior implementation (#62)
* fix GB prior and remove clipped ellipticity sampling until fixed in issue later * docstring * refactor to allow more parameters in prior * import remove * update test with changes * update pipeline * update function * use new prior * update test * I like dummy variables * test * jit is fine * import jit * too strict * slight modification in docstring * prior should already include sigma_e_int * propagate change * propgate other changes to scripts for experiments * new figures for exp1 * update scripts for exp3 * figures and updates on exp3
- Loading branch information
1 parent
cf26964
commit 784d39b
Showing
30 changed files
with
484 additions
and
265 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,50 @@ | ||
from functools import partial | ||
from typing import Callable | ||
|
||
import jax.numpy as jnp | ||
import jax.scipy as jsp | ||
from jax import Array, grad, vmap | ||
from jax.numpy.linalg import norm | ||
from jax.typing import ArrayLike | ||
|
||
from bpd.prior import inv_shear_func1, inv_shear_func2, inv_shear_transformation | ||
from bpd.prior import ( | ||
ellip_prior_e1e2, | ||
inv_shear_func1, | ||
inv_shear_func2, | ||
inv_shear_transformation, | ||
) | ||
|
||
_grad_fnc1 = vmap(vmap(grad(inv_shear_func1), in_axes=(0, None)), in_axes=(0, None)) | ||
_grad_fnc2 = vmap(vmap(grad(inv_shear_func2), in_axes=(0, None)), in_axes=(0, None)) | ||
_inv_shear_trans = vmap(inv_shear_transformation, in_axes=(0, None)) | ||
|
||
|
||
def shear_loglikelihood( | ||
g: Array, | ||
sigma_e: float, | ||
e_post: Array, | ||
*, | ||
prior: Callable, | ||
interim_prior: Callable, # fixed | ||
) -> ArrayLike: | ||
# Given by the inference procedure in Schneider et al. 2014 | ||
# assume single shear g | ||
# assume e_obs.shape == (N, K, 2) where N is number of galaxies, K is samples per galaxy | ||
# the priors are callables for now on only ellipticities | ||
# the interim_prior should have been used when obtaining e_obs from the chain (i.e. for now same sigma) | ||
_, _, _ = e_post.shape # (N, K, 2) | ||
_prior = partial(prior, sigma=sigma_e) | ||
|
||
e_post_mag = norm(e_post, axis=-1) | ||
denom = interim_prior(e_post_mag) # (N, K), can ignore angle in prior as uniform | ||
def true_ellip_logprior(e_post: Array, g: Array, *, sigma_e: float): | ||
"""Implementation of GB's true prior on interim posterior samples of ellipticities.""" | ||
|
||
# for num, use trick | ||
# p(w_n' | g, alpha ) = p(w_n' \cross^{-1} g | alpha ) = p(w_n | alpha) * |jac(w_n / w_n')| | ||
|
||
# shape = (N, K, 2) | ||
# jacobian of inverse shear transformation | ||
grad1 = _grad_fnc1(e_post, g) | ||
grad2 = _grad_fnc2(e_post, g) | ||
absjacdet = jnp.abs(grad1[..., 0] * grad2[..., 1] - grad1[..., 1] * grad2[..., 0]) | ||
|
||
# true prior on unsheared ellipticity | ||
e_post_unsheared = _inv_shear_trans(e_post, g) | ||
e_post_unsheared_mag = norm(e_post_unsheared, axis=-1) | ||
num = _prior(e_post_unsheared_mag) * absjacdet # (N, K) | ||
prior_val = ellip_prior_e1e2(e_post_unsheared, sigma=sigma_e) | ||
|
||
return jnp.log(prior_val) + jnp.log(absjacdet) | ||
|
||
|
||
def shear_loglikelihood( | ||
g: Array, | ||
post_params: dict[str, Array], | ||
*, | ||
logprior: Callable, | ||
interim_logprior: Callable, # fixed | ||
) -> ArrayLike: | ||
"""Shear Likelihood implementation of Schneider et al. 2014.""" | ||
e_post = post_params["e1e2"] | ||
_, _, _ = e_post.shape # (N, K, 2) | ||
|
||
denom = interim_logprior(post_params) | ||
num = logprior(post_params, g) | ||
|
||
ratio = jsp.special.logsumexp(jnp.log(num) - jnp.log(denom), axis=-1) | ||
ratio = jsp.special.logsumexp(num - denom, axis=-1) | ||
return ratio.sum() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,69 @@ | ||
from functools import partial | ||
from typing import Callable | ||
|
||
from jax import Array | ||
from jax import jit as jjit | ||
import jax.numpy as jnp | ||
from jax import Array, jit | ||
from jax._src.prng import PRNGKeyArray | ||
from jax.scipy import stats | ||
|
||
from bpd.chains import run_inference_nuts | ||
from bpd.likelihood import shear_loglikelihood | ||
from bpd.prior import ellip_mag_prior | ||
from bpd.likelihood import shear_loglikelihood, true_ellip_logprior | ||
from bpd.prior import ellip_prior_e1e2 | ||
|
||
|
||
def logtarget_density(g: Array, *, data: Array, loglikelihood: Callable): | ||
loglike = loglikelihood(g, e_post=data) | ||
logprior = stats.uniform.logpdf(g, -0.1, 0.2).sum() | ||
def logtarget_density( | ||
g: Array, *, data: Array, loglikelihood: Callable, sigma_g: float = 0.01 | ||
): | ||
loglike = loglikelihood(g, post_params=data) | ||
logprior = stats.norm.logpdf(g, loc=0.0, scale=sigma_g).sum() | ||
return logprior + loglike | ||
|
||
|
||
def pipeline_shear_inference( | ||
def _logprior(post_params: dict[str, Array], g: Array, *, sigma_e: float): | ||
e_post = post_params["e1e2"] | ||
return true_ellip_logprior(e_post, g, sigma_e=sigma_e) | ||
|
||
|
||
def _interim_logprior(post_params: dict[str, Array], sigma_e_int: float): | ||
e_post = post_params["e1e2"] | ||
return jnp.log(ellip_prior_e1e2(e_post, sigma=sigma_e_int)) | ||
|
||
|
||
def pipeline_shear_inference_ellipticities( | ||
rng_key: PRNGKeyArray, | ||
e_post: Array, | ||
init_g: Array, | ||
*, | ||
true_g: Array, | ||
sigma_e: float, | ||
sigma_e_int: float, | ||
n_samples: int, | ||
initial_step_size: float, | ||
sigma_g: float = 0.01, | ||
n_warmup_steps: int = 500, | ||
max_num_doublings: int = 2, | ||
): | ||
interim_prior = partial(ellip_mag_prior, sigma=sigma_e_int) | ||
|
||
# NOTE: jit must be applied without `e_post` in partial! | ||
_loglikelihood = jjit( | ||
_loglikelihood = jit( | ||
partial( | ||
shear_loglikelihood, | ||
sigma_e=sigma_e, | ||
prior=ellip_mag_prior, | ||
interim_prior=interim_prior, | ||
logprior=partial(_logprior, sigma_e=sigma_e), | ||
interim_logprior=partial(_interim_logprior, sigma_e_int=sigma_e_int), | ||
) | ||
) | ||
_logtarget = partial(logtarget_density, loglikelihood=_loglikelihood) | ||
_logtarget = partial( | ||
logtarget_density, loglikelihood=_loglikelihood, sigma_g=sigma_g | ||
) | ||
|
||
_do_inference = partial( | ||
run_inference_nuts, | ||
data=e_post, | ||
data={"e1e2": e_post}, | ||
logtarget=_logtarget, | ||
n_samples=n_samples, | ||
n_warmup_steps=n_warmup_steps, | ||
max_num_doublings=max_num_doublings, | ||
initial_step_size=initial_step_size, | ||
) | ||
|
||
g_samples = _do_inference(rng_key, true_g) | ||
g_samples = _do_inference(rng_key, init_g) | ||
|
||
return g_samples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.