diff --git a/bpd/likelihood.py b/bpd/likelihood.py index 2b93d9d..efbed1b 100644 --- a/bpd/likelihood.py +++ b/bpd/likelihood.py @@ -9,7 +9,7 @@ def shear_loglikelihood( g: Array, - post_params: dict[str, Array], + post_params: dict[str, Array] | Array, *, logprior: Callable, interim_logprior: Callable, # fixed diff --git a/bpd/pipelines.py b/bpd/pipelines.py index af9e97a..ae1e4a4 100644 --- a/bpd/pipelines.py +++ b/bpd/pipelines.py @@ -9,11 +9,13 @@ from bpd.chains import run_inference_nuts from bpd.likelihood import shear_loglikelihood -from bpd.prior import ellip_prior_e1e2 +from bpd.prior import ellip_prior_e1e2, true_ellip_logprior from bpd.sample import sample_noisy_ellipticities_unclipped -def logtarget_shear(g: Array, *, data: Array, loglikelihood: Callable, sigma_g: float): +def logtarget_shear( + g: Array, *, data: Array | dict[str, Array], loglikelihood: Callable, sigma_g: float +): loglike = loglikelihood(g, post_params=data) logprior = stats.norm.logpdf(g, loc=0.0, scale=sigma_g).sum() return logprior + loglike @@ -55,6 +57,44 @@ def pipeline_shear_inference( return g_samples +def pipeline_shear_inference_simple( + rng_key: PRNGKeyArray, + e_post: Array, + init_g: Array, + *, + sigma_e: float, + sigma_e_int: float, + n_samples: int, + initial_step_size: float, + sigma_g: float = 0.01, + n_warmup_steps: int = 500, + max_num_doublings: int = 2, +): + _logprior = lambda e, g: true_ellip_logprior(e, g, sigma_e=sigma_e) + _interim_logprior = lambda e: jnp.log(ellip_prior_e1e2(e, sigma=sigma_e_int)) + + _loglikelihood = partial( + shear_loglikelihood, logprior=_logprior, interim_logprior=_interim_logprior + ) + _loglikelihood_jitted = jit(_loglikelihood) + + _logtarget = partial( + logtarget_shear, loglikelihood=_loglikelihood_jitted, sigma_g=sigma_g + ) + + _do_inference = partial( + run_inference_nuts, + data=e_post, + logtarget=_logtarget, + n_samples=n_samples, + n_warmup_steps=n_warmup_steps, + max_num_doublings=max_num_doublings, + initial_step_size=initial_step_size, + ) + + return _do_inference(rng_key, init_g) + + def logtarget_images( params: dict[str, Array], data: Array,