Skip to content

Commit

Permalink
use generalized version
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Nov 16, 2024
1 parent 9b9c37e commit 0a34c5b
Showing 1 changed file with 12 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
from jax import random, vmap

from bpd import DATA_DIR
from bpd.draw import draw_gaussian
from bpd.initialization import init_with_truth
from bpd.io import save_dataset
from bpd.pipelines.image_samples_fixed_flux import (
get_target_galaxy_params_simple,
from bpd.pipelines.image_samples import (
get_target_images,
get_true_params_from_galaxy_params,
pipeline_image_interim_samples_one_galaxy,
pipeline_interim_samples_one_galaxy,
)
from bpd.pipelines.image_samples_fixed_flux import (
get_target_galaxy_params_simple,
logprior,
)

INIT_FNC = init_with_truth
Expand Down Expand Up @@ -52,6 +56,7 @@ def main(
)
galaxy_params = vmap(_get_galaxy_params)(pkeys)
assert galaxy_params["x"].shape == (n_gals,)
assert "lf" not in galaxy_params and "hlr" not in galaxy_params # not inferring

# now get corresponding target images
# we use the same flux and hlr for every galaxy in this experiment (and fix them in sampling)
Expand All @@ -67,17 +72,18 @@ def main(

# prepare pipelines
gkeys = random.split(gkey, n_gals)
_draw_fnc = partial(draw_gaussian, f=10**lf, hlr=hlr, slen=slen, fft_size=fft_size)
pipe = partial(
pipeline_image_interim_samples_one_galaxy,
pipeline_interim_samples_one_galaxy,
initialization_fnc=INIT_FNC,
draw_fnc=_draw_fnc,
logprior=logprior,
sigma_e_int=sigma_e_int,
n_samples=n_samples_per_gal,
initial_step_size=initial_step_size,
slen=slen,
fft_size=fft_size,
background=background,
f=10**lf,
hlr=hlr,
)
vpipe = vmap(jjit(pipe), (0, 0, 0))

Expand Down

0 comments on commit 0a34c5b

Please sign in to comment.