diff --git a/scripts/get_image_interim_samples_fixed1.py b/scripts/get_image_interim_samples_fixed_flux.py similarity index 89% rename from scripts/get_image_interim_samples_fixed1.py rename to scripts/get_image_interim_samples_fixed_flux.py index 649c154..86fe345 100755 --- a/scripts/get_image_interim_samples_fixed1.py +++ b/scripts/get_image_interim_samples_fixed_flux.py @@ -8,13 +8,17 @@ from jax import random, vmap from bpd import DATA_DIR +from bpd.draw import draw_gaussian from bpd.initialization import init_with_truth from bpd.io import save_dataset -from bpd.pipelines.image_samples_fixed_flux import ( - get_target_galaxy_params_simple, +from bpd.pipelines.image_samples import ( get_target_images, get_true_params_from_galaxy_params, - pipeline_image_interim_samples_one_galaxy, + pipeline_interim_samples_one_galaxy, +) +from bpd.pipelines.image_samples_fixed_flux import ( + get_target_galaxy_params_simple, + logprior, ) INIT_FNC = init_with_truth @@ -52,6 +56,7 @@ def main( ) galaxy_params = vmap(_get_galaxy_params)(pkeys) assert galaxy_params["x"].shape == (n_gals,) + assert "lf" not in galaxy_params and "hlr" not in galaxy_params # not inferring # now get corresponding target images # we use the same flux and hlr for every galaxy in this experiment (and fix them in sampling) @@ -67,17 +72,18 @@ def main( # prepare pipelines gkeys = random.split(gkey, n_gals) + _draw_fnc = partial(draw_gaussian, f=10**lf, hlr=hlr, slen=slen, fft_size=fft_size) pipe = partial( - pipeline_image_interim_samples_one_galaxy, + pipeline_interim_samples_one_galaxy, initialization_fnc=INIT_FNC, + draw_fnc=_draw_fnc, + logprior=logprior, sigma_e_int=sigma_e_int, n_samples=n_samples_per_gal, initial_step_size=initial_step_size, slen=slen, fft_size=fft_size, background=background, - f=10**lf, - hlr=hlr, ) vpipe = vmap(jjit(pipe), (0, 0, 0))