Commit 061d14f

much less needs to be added see

ismael-mendoza committed Jan 14, 2025
1 parent 7444a5d
Showing 1 changed file with 5 additions and 54 deletions:

experiments/exp32/get_shear.py
@@ -2,28 +2,17 @@
 """This file creates toy samples of ellipticities and saves them to .hdf5 file."""
 
 from functools import partial
-from typing import Callable
 
 import jax
 import jax.numpy as jnp
 import typer
-from jax import Array, jit
-from jax._src.prng import PRNGKeyArray
+from jax import Array
 from jax.scipy import stats
 
 from bpd import DATA_DIR
 from bpd.chains import run_inference_nuts
 from bpd.io import load_dataset
-from bpd.likelihood import shear_loglikelihood, true_ellip_logprior
-from bpd.pipelines.image_samples import logprior
-
-
-def logtarget_density(
-    g: Array, *, data: Array, loglikelihood: Callable, sigma_g: float = 0.01
-):
-    loglike = loglikelihood(g, post_params=data)
-    logprior = stats.norm.logpdf(g, loc=0.0, scale=sigma_g).sum()
-    return logprior + loglike
+from bpd.pipelines import pipeline_shear_inference
+from bpd.prior import interim_gprops_logprior, true_ellip_logprior
 
 
 def _logprior(
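
For context, the deleted logtarget_density composed the shear log-posterior as a zero-centered Gaussian log-prior on g plus the marginal log-likelihood evaluated on the interim posterior samples; after this commit that logic is expected to live behind bpd.pipelines.pipeline_shear_inference. Below is a minimal standalone sketch of the same density, using a hypothetical quadratic stand-in for the likelihood so that it runs:

import jax.numpy as jnp
from jax.scipy import stats

def logtarget_density(g, *, data, loglikelihood, sigma_g=0.01):
    # log-posterior of the shear g: zero-centered Gaussian log-prior of
    # width sigma_g plus the marginal log-likelihood over the samples
    loglike = loglikelihood(g, post_params=data)
    logprior = stats.norm.logpdf(g, loc=0.0, scale=sigma_g).sum()
    return logprior + loglike

# hypothetical quadratic stand-in for the likelihood, only to make this runnable
toy_loglike = lambda g, post_params: -0.5 * jnp.sum((g - post_params) ** 2)
print(logtarget_density(jnp.array([0.02, 0.0]), data=jnp.zeros(2), loglikelihood=toy_loglike))

The tight default sigma_g = 0.01 encodes the expectation that the inferred shear is small, keeping the sampler near zero.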
@@ -54,49 +43,11 @@ def _logprior(
 
 def _interim_logprior(post_params: dict[str, Array], sigma_e_int: float):
     # we do not evaluate dxdy as we assume it's the same as the true prior and they cancel
-    return logprior(
+    return interim_gprops_logprior(
         post_params, sigma_e=sigma_e_int, free_flux_hlr=True, free_dxdy=False
     )
 
 
-def pipeline_shear_inference(
-    rng_key: PRNGKeyArray,
-    post_params: Array,
-    init_g: Array,
-    *,
-    logprior: Callable,
-    interim_logprior: Callable,
-    n_samples: int,
-    initial_step_size: float,
-    sigma_g: float = 0.01,
-    n_warmup_steps: int = 500,
-    max_num_doublings: int = 2,
-):
-    # NOTE: jit must be applied without `e_post` in partial!
-    _loglikelihood = jit(
-        partial(
-            shear_loglikelihood, logprior=logprior, interim_logprior=interim_logprior
-        )
-    )
-    _logtarget = partial(
-        logtarget_density, loglikelihood=_loglikelihood, sigma_g=sigma_g
-    )
-
-    _do_inference = partial(
-        run_inference_nuts,
-        data=post_params,
-        logtarget=_logtarget,
-        n_samples=n_samples,
-        n_warmup_steps=n_warmup_steps,
-        max_num_doublings=max_num_doublings,
-        initial_step_size=initial_step_size,
-    )
-
-    g_samples = _do_inference(rng_key, init_g)
-
-    return g_samples
-
-
 def main(
     seed: int,
     initial_step_size: float = 1e-3,
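
The comment inside _interim_logprior above is doing real work: in this importance-sampling setup the likelihood only ever consumes the ratio of the true prior to the interim prior, so any factor common to both, such as the dx/dy position prior here, cancels and never needs to be evaluated. The following is a schematic sketch of that ratio structure under assumed array shapes; it is not the actual shear_loglikelihood from bpd.likelihood, and both prior callables are hypothetical:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def shear_loglike_sketch(g, e_post, *, logprior_true, logprior_interim):
    # e_post: (n_gals, k) interim-posterior ellipticity draws per galaxy;
    # each galaxy contributes log (1/k) sum_k p_true(e_k | g) / q_int(e_k),
    # and any factor shared by p_true and q_int drops out of the ratio
    log_ratio = logprior_true(e_post, g) - logprior_interim(e_post)
    return jnp.sum(logsumexp(log_ratio, axis=-1) - jnp.log(e_post.shape[-1]))

# toy Gaussian log-priors (hypothetical) acting on scalar ellipticities
print(shear_loglike_sketch(
    0.02,
    jnp.zeros((100, 10)),
    logprior_true=lambda e, g: -0.5 * ((e - g) / 0.3) ** 2,
    logprior_interim=lambda e: -0.5 * (e / 0.4) ** 2,
))

Under this reading, skipping the shared dx/dy factor on both sides (free_dxdy=False) is purely an optimization; the resulting density is unchanged.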
Expand Down Expand Up @@ -147,9 +98,9 @@ def main(
g_samples = pipeline_shear_inference(
rng_key,
post_params,
init_g=true_g,
logprior=logprior_fnc,
interim_logprior=interim_logprior_fnc,
init_g=true_g,
n_samples=n_samples,
initial_step_size=initial_step_size,
)
Expand Down
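
Finally, the deleted pipeline carried a NOTE worth preserving: jit must be applied without `e_post` in the partial. Binding only the prior callables with functools.partial and passing the large array of posterior samples as a traced argument keeps that array out of the compiled function's baked-in constants. A small self-contained sketch of the pattern, with a hypothetical stand-in for shear_loglikelihood:

from functools import partial

import jax.numpy as jnp
from jax import jit

def shear_loglikelihood(g, *, post_params, logprior, interim_logprior):
    # stand-in with the same keyword layout as the removed code
    return jnp.sum(logprior(post_params, g) - interim_logprior(post_params))

# bind only the callables, then jit; post_params is supplied at call time
_loglikelihood = jit(
    partial(
        shear_loglikelihood,
        logprior=lambda e, g: -0.5 * (e - g) ** 2,
        interim_logprior=lambda e: -0.5 * e**2,
    )
)
print(_loglikelihood(jnp.array(0.01), post_params=jnp.zeros(1000)))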
