Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(exp 3.2) demonstrate shear inference with all params free but narrow flux/hlr priors #68

Merged
merged 44 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
59bd535
prior from James
ismael-mendoza Dec 12, 2024
980afaa
start draft for exp4 with all params free
ismael-mendoza Dec 12, 2024
d616f38
Merge branch 'main' into shear-more-free-gparams1
ismael-mendoza Dec 30, 2024
44a16fa
need ipykernel
ismael-mendoza Dec 30, 2024
2857b97
add optional deviation from true centroid to draw function
ismael-mendoza Dec 31, 2024
c6f7ab8
split up flux and centroid prior evaluation
ismael-mendoza Dec 31, 2024
1f43ded
fix exp3
ismael-mendoza Dec 31, 2024
83d4255
for posterity
ismael-mendoza Dec 31, 2024
ed0f7cc
revert
ismael-mendoza Dec 31, 2024
6a4b817
fix exp2
ismael-mendoza Dec 31, 2024
606a7c1
add exp with centroid free
ismael-mendoza Dec 31, 2024
b4f77ce
rename folder
ismael-mendoza Dec 31, 2024
5926f3d
use log hlr so always positive no hard bounds
ismael-mendoza Dec 31, 2024
8fd717a
draft exp32 getting samples from galaxies
ismael-mendoza Dec 31, 2024
4bdcf98
adjusting shear inference computation
ismael-mendoza Dec 31, 2024
283e5c0
readme
ismael-mendoza Dec 31, 2024
c76543c
draft more scripts
ismael-mendoza Jan 1, 2025
c4ed9aa
demonstrate scaling of HLR and flux
ismael-mendoza Jan 3, 2025
4a4c47f
might need to rethink draw
ismael-mendoza Jan 3, 2025
0ec8a09
draft accounting for magnification
ismael-mendoza Jan 3, 2025
0b8f33a
rename folder
ismael-mendoza Jan 6, 2025
20777c1
remove dxdy as can easily cause bugs
ismael-mendoza Jan 6, 2025
eb5192d
manipulate parameters inside loglikelihood to allow for dxdy
ismael-mendoza Jan 6, 2025
4bfb21c
update tile in README
ismael-mendoza Jan 6, 2025
8933f46
update with new code structure, also make sure that we agree with gal…
ismael-mendoza Jan 6, 2025
a9555a6
fix bug
ismael-mendoza Jan 6, 2025
141339b
update exp30 and rerun
ismael-mendoza Jan 6, 2025
04ed5c5
space
ismael-mendoza Jan 6, 2025
07956bb
update and rerun exp31
ismael-mendoza Jan 6, 2025
d1455ed
comment update
ismael-mendoza Jan 6, 2025
84c3371
update exp32 and add magnification jacobian
ismael-mendoza Jan 6, 2025
783bffd
no need for this check
ismael-mendoza Jan 6, 2025
0e59710
minor fixes
ismael-mendoza Jan 6, 2025
12d33ca
fix implementation
ismael-mendoza Jan 6, 2025
4d13d55
add figures
ismael-mendoza Jan 6, 2025
7e8e026
add seeds for reproducibility of figures
ismael-mendoza Jan 6, 2025
08838ac
updates
ismael-mendoza Jan 6, 2025
067acb6
Merge branch 'main' into more-free-gparams2
ismael-mendoza Jan 7, 2025
6b18ba8
comment
ismael-mendoza Jan 7, 2025
354da95
fix draw
ismael-mendoza Jan 7, 2025
102b852
fix test but some questions remain about precision
ismael-mendoza Jan 7, 2025
2ad2081
tolerance
ismael-mendoza Jan 7, 2025
8baad75
fix exp2 so it runs as before
ismael-mendoza Jan 7, 2025
f369b91
fix
ismael-mendoza Jan 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ git clone [email protected]:LSSTDESC/BPD.git
cd BPD
pip install -e .
pip install -e ".[dev]"

# Might be necessary
conda install -c nvidia cuda-nvcc
```
19 changes: 10 additions & 9 deletions bpd/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,24 @@
from jax_galsim import GSParams


# forward model
def draw_gaussian(
*,
f: float,
hlr: float,
e1: float,
e2: float,
g1: float,
g2: float,
x: float,
x: float, # pixels
y: float,
*,
slen: int,
fft_size: int, # rule of thumb: at least 4 times `slen`
psf_hlr: float = 0.7,
pixel_scale: float = 0.2,
):
gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)

# x, y arguments in pixels
gal = xgalsim.Gaussian(flux=f, half_light_radius=hlr)
gal = gal.shear(g1=e1, g2=e2)
gal = gal.shear(g1=g1, g2=g2)

psf = xgalsim.Gaussian(flux=1.0, half_light_radius=psf_hlr)
gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(gsparams)
Expand All @@ -32,6 +29,7 @@ def draw_gaussian(


def draw_gaussian_galsim(
*,
f: float,
hlr: float,
e1: float,
Expand All @@ -40,14 +38,17 @@ def draw_gaussian_galsim(
g2: float,
x: float, # pixels
y: float,
*,
slen: int,
psf_hlr: float = 0.7,
pixel_scale: float = 0.2,
):
gal = galsim.Gaussian(flux=f, half_light_radius=hlr)
gal = gal.shear(g1=e1, g2=e2)
gal = gal.shear(g1=g1, g2=g2)
gal = gal.shear(g1=e1, g2=e2) # intrinsic ellipticity

# the correct weak lensing effect includes magnification even if kappa=0!
# see: https://galsim-developers.github.io/GalSim/_build/html/shear.html
mu = (1 - g1**2 - g2**2) ** -1 # convergence kappa = 0
gal = gal.lens(g1=g1, g2=g2, mu=mu)

psf = galsim.Gaussian(flux=1.0, half_light_radius=psf_hlr)
gal_conv = galsim.Convolve([gal, psf])
Expand Down
4 changes: 0 additions & 4 deletions bpd/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,7 @@ def shear_loglikelihood(
interim_logprior: Callable, # fixed
) -> ArrayLike:
"""Shear Likelihood implementation of Schneider et al. 2014."""
e_post = post_params["e1e2"]
_, _, _ = e_post.shape # (N, K, 2)

denom = interim_logprior(post_params)
num = logprior(post_params, g)

ratio = jsp.special.logsumexp(num - denom, axis=-1)
return ratio.sum()
90 changes: 52 additions & 38 deletions bpd/pipelines/image_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,43 @@ def sample_target_galaxy_params_simple(
}


def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]):
true_params = {**galaxy_params}
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
g1, g2 = true_params.pop("g1"), true_params.pop("g2")

e1_prime, e2_prime = scalar_shear_transformation(
jnp.array([e1, e2]), jnp.array([g1, g2])
)
true_params["e1"] = e1_prime
true_params["e2"] = e2_prime

return true_params # don't add back g1,g2 as we are not inferring those in interim posterior


# interim prior
def logprior(
params: dict[str, Array],
*,
sigma_e: float,
sigma_x: float = 0.5, # pixels
flux_bds: tuple = (-1.0, 9.0),
hlr_bds: tuple = (0.01, 5.0),
all_free: bool = True,
hlr_bds: tuple = (-2.0, 1.0),
free_flux_hlr: bool = True,
free_dxdy: bool = True,
) -> Array:
prior = jnp.array(0.0)

if all_free:
if free_flux_hlr:
f1, f2 = flux_bds
prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1)

h1, h2 = hlr_bds
prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1)
prior += stats.uniform.logpdf(params["lhlr"], h1, h2 - h1)

# NOTE: hard-coded assumption that galaxy is in center-pixel within odd-size image.
# sigma_x in units of pixels.
prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x)
if free_dxdy:
prior += stats.norm.logpdf(params["dx"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["dy"], loc=0.0, scale=sigma_x)

e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1)
prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e))
Expand All @@ -76,17 +90,33 @@ def logprior(
def loglikelihood(
params: dict[str, Array],
data: Array,
fixed_params: dict[str, Array],
*,
draw_fnc: Callable,
background: float,
free_flux: bool = True,
free_flux_hlr: bool = True,
free_dxdy: bool = True,
):
# NOTE: draw_fnc should already contain `f` and `hlr` as constant arguments if fixed
_draw_params = {**{"g1": 0.0, "g2": 0.0}, **params}
_draw_params = {}

if free_dxdy:
_draw_params["x"] = params["dx"] + fixed_params["x"]
_draw_params["y"] = params["dy"] + fixed_params["y"]

else:
_draw_params["x"] = fixed_params["x"]
_draw_params["y"] = fixed_params["y"]

if free_flux_hlr:
_draw_params["f"] = 10 ** params["lf"]
_draw_params["hlr"] = 10 ** params["lhlr"]

# Convert log-flux to flux if provided
if free_flux:
_draw_params["f"] = 10 ** _draw_params.pop("lf")
else:
_draw_params["f"] = fixed_params["f"]
_draw_params["hlr"] = fixed_params["hlr"]

_draw_params["e1"] = params["e1"]
_draw_params["e2"] = params["e2"]

model = draw_fnc(**_draw_params)
likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background))
Expand All @@ -97,24 +127,11 @@ def logtarget(
params: dict[str, Array],
data: Array,
*,
fixed_params: dict[str, Array],
logprior_fnc: Callable,
loglikelihood_fnc: Callable,
):
return logprior_fnc(params) + loglikelihood_fnc(params, data)


def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]):
true_params = {**galaxy_params}
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
g1, g2 = true_params.pop("g1"), true_params.pop("g2")

e1_prime, e2_prime = scalar_shear_transformation(
jnp.array([e1, e2]), jnp.array([g1, g2])
)
true_params["e1"] = e1_prime
true_params["e2"] = e2_prime

return true_params # don't add g1,g2 back as we are not inferring those in interim posterior
return logprior_fnc(params) + loglikelihood_fnc(params, data, fixed_params)


def get_target_images_single(
Expand Down Expand Up @@ -156,30 +173,27 @@ def pipeline_interim_samples_one_galaxy(
rng_key: PRNGKeyArray,
true_params: dict[str, float],
target_image: Array,
fixed_draw_kwargs: dict,
fixed_params: dict[str, float],
*,
initialization_fnc: Callable,
draw_fnc: Callable,
logprior: Callable,
loglikelihood: Callable,
n_samples: int = 100,
max_num_doublings: int = 5,
initial_step_size: float = 1e-3,
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = True,
background: float = 1.0,
free_flux: bool = True,
):
# Flux and HLR are fixed to truth and not inferred in this function.
k1, k2 = random.split(rng_key)

init_position = initialization_fnc(k1, true_params=true_params, data=target_image)
_draw_fnc = partial(draw_fnc, **fixed_draw_kwargs)
_loglikelihood = partial(
loglikelihood, draw_fnc=_draw_fnc, background=background, free_flux=free_flux
)

_logtarget = partial(
logtarget, logprior_fnc=logprior, loglikelihood_fnc=_loglikelihood
logtarget,
logprior_fnc=logprior,
loglikelihood_fnc=loglikelihood,
fixed_params=fixed_params,
)

_inference_fnc = partial(
Expand Down
59 changes: 54 additions & 5 deletions experiments/exp2/run_inference_galaxy_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import time
from functools import partial
from typing import Callable

import jax.numpy as jnp
import typer
from jax import jit, random, vmap
from jax import Array, jit, random, vmap
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

from bpd import DATA_DIR
from bpd.chains import run_sampling_nuts, run_warmup_nuts
Expand All @@ -16,11 +18,58 @@
from bpd.pipelines.image_samples import (
get_target_images,
get_true_params_from_galaxy_params,
loglikelihood,
logprior,
logtarget,
sample_target_galaxy_params_simple,
)
from bpd.prior import ellip_prior_e1e2


def logprior(
params: dict[str, Array],
*,
sigma_e: float,
sigma_x: float = 0.5, # pixels
flux_bds: tuple = (-1.0, 9.0),
hlr_bds: tuple = (0.01, 5.0),
) -> Array:
prior = jnp.array(0.0)

f1, f2 = flux_bds
prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1)

h1, h2 = hlr_bds
prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1)

prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x)

e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1)
prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e))

return prior


def loglikelihood(
params: dict[str, Array],
data: Array,
*,
draw_fnc: Callable,
background: float,
):
_draw_params = {**params}
_draw_params["f"] = 10 ** _draw_params.pop("lf")
model = draw_fnc(**_draw_params)
likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background))
return jnp.sum(likelihood_pp)


def logtarget(
params: dict[str, Array],
data: Array,
*,
logprior_fnc: Callable,
loglikelihood_fnc: Callable,
):
return logprior_fnc(params) + loglikelihood_fnc(params, data)


def sample_prior(
Expand Down Expand Up @@ -67,7 +116,7 @@ def main(
pkey, nkey, ikey, rkey = random.split(rng_key, 4)

# directory structure
dirpath = DATA_DIR / "cache_chains" / f"test_image_sampling_{seed}"
dirpath = DATA_DIR / "cache_chains" / f"exp2_{seed}"
if not dirpath.exists():
dirpath.mkdir(exist_ok=True)
fpath = dirpath / f"chain_results_{seed}.npy"
Expand Down
Binary file removed experiments/exp3/figs/contours.pdf
Binary file not shown.
Binary file removed experiments/exp3/figs/scatter_shapes.pdf
Binary file not shown.
Binary file removed experiments/exp3/figs/traces.pdf
Binary file not shown.
2 changes: 1 addition & 1 deletion experiments/exp3/README.md → experiments/exp30/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Experiment 3
# Experiment 3.0

This folder contains scripts to reproduce results on fitting galaxy images in the low noise setting
and inferring shear from them. Details:
Expand Down
Binary file added experiments/exp30/figs/contours.pdf
Binary file not shown.
Binary file added experiments/exp30/figs/hists.pdf
Binary file not shown.
Binary file added experiments/exp30/figs/scatter_shapes.pdf
Binary file not shown.
Binary file added experiments/exp30/figs/traces.pdf
Binary file not shown.
File renamed without changes.
Loading
Loading