Skip to content

Commit

Permalink
change to use galaxies of different ellipticities
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Nov 11, 2024
1 parent ce8c066 commit 5fadf80
Showing 1 changed file with 4 additions and 13 deletions.
17 changes: 4 additions & 13 deletions scripts/get_image_interim_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from bpd.io import save_dataset
from bpd.pipelines.image_ellips import (
get_target_galaxy_params_simple,
get_target_images_single,
get_target_images,
get_true_params_from_galaxy_params,
pipeline_image_interim_samples_one_galaxy,
)
Expand Down Expand Up @@ -57,25 +57,19 @@ def main(

# now get corresponding target images
# 'lf' is used for inference, but 'f' is used for drawing
nkeys = random.split(nkey, n_gals)
draw_params = {**galaxy_params}
draw_params["f"] = 10 ** draw_params.pop("lf")

_get_target_images_single = partial(
get_target_images_single,
n_samples=1, # one noise realization per galaxy
background=background,
slen=slen,
target_images = get_target_images(
nkey, draw_params, background=background, slen=slen
)
_get_target_images = vmap(_get_target_images_single)
target_images = _get_target_images(nkeys, draw_params)
assert target_images.shape == (n_gals, slen, slen)

# finally, interim samples are on 'sheared ellipticity'
_get_true_params = vmap(get_true_params_from_galaxy_params)
true_params = _get_true_params(galaxy_params)

# prepare pipelines
gkeys = random.split(gkey, n_gals)
pipe = partial(
pipeline_image_interim_samples_one_galaxy,
initialization_fnc=INIT_FNC,
Expand All @@ -88,9 +82,6 @@ def main(
)
vpipe = vmap(jjit(pipe), (0, 0, 0))

# initialization
gkeys = random.split(gkey, n_gals)

# compilation on single target image
_ = vpipe(
gkeys[0, None],
Expand Down

0 comments on commit 5fadf80

Please sign in to comment.