Skip to content

Commit

Permalink
it's useful to keep the simple pipeline that only does ellipticities
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Jan 14, 2025
1 parent 0366b9e commit 19f4432
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
2 changes: 1 addition & 1 deletion bpd/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def shear_loglikelihood(
g: Array,
post_params: dict[str, Array],
post_params: dict[str, Array] | Array,
*,
logprior: Callable,
interim_logprior: Callable, # fixed
Expand Down
44 changes: 42 additions & 2 deletions bpd/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

from bpd.chains import run_inference_nuts
from bpd.likelihood import shear_loglikelihood
from bpd.prior import ellip_prior_e1e2
from bpd.prior import ellip_prior_e1e2, true_ellip_logprior
from bpd.sample import sample_noisy_ellipticities_unclipped


def logtarget_shear(g: Array, *, data: Array, loglikelihood: Callable, sigma_g: float):
def logtarget_shear(
g: Array, *, data: Array | dict[str, Array], loglikelihood: Callable, sigma_g: float
):
loglike = loglikelihood(g, post_params=data)
logprior = stats.norm.logpdf(g, loc=0.0, scale=sigma_g).sum()
return logprior + loglike
Expand Down Expand Up @@ -55,6 +57,44 @@ def pipeline_shear_inference(
return g_samples


def pipeline_shear_inference_simple(
rng_key: PRNGKeyArray,
e_post: Array,
init_g: Array,
*,
sigma_e: float,
sigma_e_int: float,
n_samples: int,
initial_step_size: float,
sigma_g: float = 0.01,
n_warmup_steps: int = 500,
max_num_doublings: int = 2,
):
_logprior = lambda e, g: true_ellip_logprior(e, g, sigma_e=sigma_e)
_interim_logprior = lambda e: jnp.log(ellip_prior_e1e2(e, sigma=sigma_e_int))

_loglikelihood = partial(
shear_loglikelihood, logprior=_logprior, interim_logprior=_interim_logprior
)
_loglikelihood_jitted = jit(_loglikelihood)

_logtarget = partial(
logtarget_shear, loglikelihood=_loglikelihood_jitted, sigma_g=sigma_g
)

_do_inference = partial(
run_inference_nuts,
data=e_post,
logtarget=_logtarget,
n_samples=n_samples,
n_warmup_steps=n_warmup_steps,
max_num_doublings=max_num_doublings,
initial_step_size=initial_step_size,
)

return _do_inference(rng_key, init_g)


def logtarget_images(
params: dict[str, Array],
data: Array,
Expand Down

0 comments on commit 19f4432

Please sign in to comment.