Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(exp 3.1) Shear Inference on galaxy images with centroid free #63

Closed
wants to merge 11 commits into from
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ git clone [email protected]:LSSTDESC/BPD.git
cd BPD
pip install -e .
pip install -e ".[dev]"

# Might be necessary
conda install -c nvidia cuda-nvcc
```
6 changes: 5 additions & 1 deletion bpd/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def draw_gaussian(
fft_size: int, # rule of thumb: at least 4 times `slen`
psf_hlr: float = 0.7,
pixel_scale: float = 0.2,
dx=0.0, # additional offset from true centroid
dy=0.0,
):
gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)

Expand All @@ -27,7 +29,9 @@ def draw_gaussian(

psf = xgalsim.Gaussian(flux=1.0, half_light_radius=psf_hlr)
gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(gsparams)
image = gal_conv.drawImage(nx=slen, ny=slen, scale=pixel_scale, offset=(x, y))
image = gal_conv.drawImage(
nx=slen, ny=slen, scale=pixel_scale, offset=(x + dx, y + dy)
)
return image.array


Expand Down
12 changes: 6 additions & 6 deletions bpd/pipelines/image_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,21 @@ def logprior(
sigma_x: float = 0.5, # pixels
flux_bds: tuple = (-1.0, 9.0),
hlr_bds: tuple = (0.01, 5.0),
all_free: bool = True,
free_flux_hlr: bool = True,
free_dxdy: bool = True,
) -> Array:
prior = jnp.array(0.0)

if all_free:
if free_flux_hlr:
f1, f2 = flux_bds
prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1)

h1, h2 = hlr_bds
prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1)

# NOTE: hard-coded assumption that galaxy is in center-pixel within odd-size image.
# sigma_x in units of pixels.
prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x)
if free_dxdy:
prior += stats.norm.logpdf(params["dx"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["dy"], loc=0.0, scale=sigma_x)

e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1)
prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e))
Expand Down
30 changes: 28 additions & 2 deletions experiments/exp2/run_inference_galaxy_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import jax.numpy as jnp
import typer
from jax import jit, random, vmap
from jax import Array, jit, random, vmap
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

from bpd import DATA_DIR
from bpd.chains import run_sampling_nuts, run_warmup_nuts
Expand All @@ -17,10 +18,35 @@
get_target_images,
get_true_params_from_galaxy_params,
loglikelihood,
logprior,
logtarget,
sample_target_galaxy_params_simple,
)
from bpd.prior import ellip_prior_e1e2


def logprior(
params: dict[str, Array],
*,
sigma_e: float,
sigma_x: float = 0.5, # pixels
flux_bds: tuple = (-1.0, 9.0),
hlr_bds: tuple = (0.01, 5.0),
) -> Array:
prior = jnp.array(0.0)

f1, f2 = flux_bds
prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1)

h1, h2 = hlr_bds
prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1)

prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x)

e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1)
prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e))

return prior


def sample_prior(
Expand Down
4 changes: 3 additions & 1 deletion experiments/exp3/get_image_interim_samples_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def main(
extra_params = {"x": x, "y": y, **extra_params}

# more setup
_logprior = partial(logprior, sigma_e=sigma_e_int, all_free=False)
_logprior = partial(
logprior, sigma_e=sigma_e_int, free_flux_hlr=False, free_dxdy=False
)

# prepare pipelines
gkeys = random.split(gkey, n_gals)
Expand Down
8 changes: 8 additions & 0 deletions experiments/exp4/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Experiment 4

Same as experiment 3 but we let deviations from true centroid `(dx, dy)` be free parameters in
the fit.


For now we assume that the true and interim prior on the deviation is the same, so it does not get
accounted for in the shear inference.
Binary file added experiments/exp4/figs/contours.pdf
Binary file not shown.
Binary file added experiments/exp4/figs/hists.pdf
Binary file not shown.
Binary file added experiments/exp4/figs/scatter_dxdy.pdf
Binary file not shown.
Binary file added experiments/exp4/figs/scatter_shapes.pdf
Binary file not shown.
Binary file added experiments/exp4/figs/traces.pdf
Binary file not shown.
122 changes: 122 additions & 0 deletions experiments/exp4/get_interim_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
from functools import partial

import jax.numpy as jnp
import typer
from jax import jit, random, vmap

from bpd import DATA_DIR
from bpd.draw import draw_gaussian
from bpd.initialization import init_with_truth
from bpd.io import save_dataset
from bpd.pipelines.image_samples import (
get_target_images,
get_true_params_from_galaxy_params,
logprior,
pipeline_interim_samples_one_galaxy,
sample_target_galaxy_params_simple,
)


def main(
seed: int,
n_gals: int = 1000, # technically, in this file it means 'noise realizations'
n_samples_per_gal: int = 100,
g1: float = 0.02,
g2: float = 0.0,
lf: float = 6.0, # ~ SNR = 1000
hlr: float = 1.0,
shape_noise: float = 1e-4,
sigma_e_int: float = 4e-2,
slen: int = 53,
fft_size: int = 256,
background: float = 1.0,
initial_step_size: float = 1e-3,
):
rng_key = random.key(seed)
pkey, nkey, gkey = random.split(rng_key, 3)

# directory structure
dirpath = DATA_DIR / "cache_chains" / f"exp4_{seed}"
if not dirpath.exists():
dirpath.mkdir(exist_ok=True)

# galaxy parameters from prior
pkeys = random.split(pkey, n_gals)
_get_galaxy_params = partial(
sample_target_galaxy_params_simple, g1=g1, g2=g2, shape_noise=shape_noise
)
galaxy_params = vmap(_get_galaxy_params)(pkeys)
assert galaxy_params["x"].shape == (n_gals,)
assert galaxy_params["e1"].shape == (n_gals,)
assert "lf" not in galaxy_params
extra_params = {"f": 10 ** jnp.full((n_gals,), lf), "hlr": jnp.full((n_gals,), hlr)}

# now get corresponding target images
draw_params = {**galaxy_params, **extra_params}
target_images = get_target_images(
nkey, draw_params, background=background, slen=slen
)
assert target_images.shape == (n_gals, slen, slen)

# interim samples are on 'sheared ellipticity'
true_params = vmap(get_true_params_from_galaxy_params)(galaxy_params)

# we pass in x,y as fixed parameters for drawing
# and initialize the function with deviations (dx, dy) = (0, 0)
x = true_params.pop("x")
y = true_params.pop("y")
true_params["dx"] = jnp.zeros_like(x)
true_params["dy"] = jnp.zeros_like(y)
extra_params = {"x": x, "y": y, **extra_params}

# more setup
_logprior = partial(
logprior, sigma_e=sigma_e_int, free_flux_hlr=False, free_dxdy=True
)

# prepare pipelines
gkeys = random.split(gkey, n_gals)
_draw_fnc = partial(draw_gaussian, slen=slen, fft_size=fft_size)
pipe = partial(
pipeline_interim_samples_one_galaxy,
initialization_fnc=init_with_truth,
draw_fnc=_draw_fnc,
logprior=_logprior,
n_samples=n_samples_per_gal,
initial_step_size=initial_step_size,
background=background,
free_flux=False,
)
vpipe = vmap(jit(pipe))

# compilation on single target image
_ = vpipe(
gkeys[0, None],
{k: v[0, None] for k, v in true_params.items()},
target_images[0, None],
{k: v[0, None] for k, v in extra_params.items()},
)

samples = vpipe(gkeys, true_params, target_images, extra_params)
e_post = jnp.stack([samples["e1"], samples["e2"]], axis=-1)
fpath = dirpath / f"e_post_{seed}.npz"

save_dataset(
{
"e_post": e_post,
"true_g": jnp.array([g1, g2]),
"dx": samples["dx"],
"dy": samples["dy"],
"sigma_e": shape_noise,
"sigma_e_int": sigma_e_int,
"e1": draw_params["e1"],
"e2": draw_params["e2"],
},
fpath,
overwrite=True,
)


if __name__ == "__main__":
typer.run(main)
7 changes: 7 additions & 0 deletions experiments/exp4/get_posteriors.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES="0"
export JAX_ENABLE_X64="True"
SEED="43"

./get_interim_samples.py $SEED
../../scripts/get_shear_from_interim_samples.py $SEED exp4_$SEED "e_post_${SEED}.npz" --overwrite
Loading
Loading