Skip to content

Commit

Permalink
Merge branch 'main' into spergel-investigation1
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza authored Jan 8, 2025
2 parents 3e647c1 + c72f59e commit 72bc24e
Show file tree
Hide file tree
Showing 49 changed files with 2,207 additions and 85 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ git clone [email protected]:LSSTDESC/BPD.git
cd BPD
pip install -e .
pip install -e ".[dev]"

# Might be necessary
conda install -c nvidia cuda-nvcc
```
19 changes: 10 additions & 9 deletions bpd/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,24 @@
from jax_galsim import GSParams


# forward model
def draw_gaussian(
*,
f: float,
hlr: float,
e1: float,
e2: float,
g1: float,
g2: float,
x: float,
x: float, # pixels
y: float,
*,
slen: int,
fft_size: int, # rule of thumb: at least 4 times `slen`
psf_hlr: float = 0.7,
pixel_scale: float = 0.2,
):
gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)

# x, y arguments in pixels
gal = xgalsim.Gaussian(flux=f, half_light_radius=hlr)
gal = gal.shear(g1=e1, g2=e2)
gal = gal.shear(g1=g1, g2=g2)

psf = xgalsim.Gaussian(flux=1.0, half_light_radius=psf_hlr)
gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(gsparams)
Expand All @@ -32,6 +29,7 @@ def draw_gaussian(


def draw_gaussian_galsim(
*,
f: float,
hlr: float,
e1: float,
Expand All @@ -40,14 +38,17 @@ def draw_gaussian_galsim(
g2: float,
x: float, # pixels
y: float,
*,
slen: int,
psf_hlr: float = 0.7,
pixel_scale: float = 0.2,
):
gal = galsim.Gaussian(flux=f, half_light_radius=hlr)
gal = gal.shear(g1=e1, g2=e2)
gal = gal.shear(g1=g1, g2=g2)
gal = gal.shear(g1=e1, g2=e2) # intrinsic ellipticity

# the correct weak lensing effect includes magnification even if kappa=0!
# see: https://galsim-developers.github.io/GalSim/_build/html/shear.html
mu = (1 - g1**2 - g2**2) ** -1 # convergence kappa = 0
gal = gal.lens(g1=g1, g2=g2, mu=mu)

psf = galsim.Gaussian(flux=1.0, half_light_radius=psf_hlr)
gal_conv = galsim.Convolve([gal, psf])
Expand Down
59 changes: 59 additions & 0 deletions bpd/jackknife.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from math import ceil
from typing import Callable

import jax.numpy as jnp
from jax import Array, random
from tqdm import tqdm


def run_jackknife_shear_pipeline(
rng_key,
init_g: Array,
post_params_pos: dict,
post_params_neg: dict,
shear_pipeline: Callable,
n_jacks: int = 10,
disable_bar: bool = True,
):
"""Use jackknife+shape noise cancellation to estimate the mean and std of the shear posterior.
Args:
rng_key: Random jax key.
init_g: Initial value for shear `g`.
post_params_pos: Interim posterior galaxy parameters estimated using positive shear.
post_params_neg: Interim posterior galaxy parameters estimated using negative shear,
and otherwise same conditions and random seed as `post_params_pos`.
shear_pipeline: Function that outputs shear posterior samples from `post_params` with all
keyword arguments pre-specified.
n_jacks: Number of jackknife batches.
Returns:
Jackknife
"""
N, _ = post_params_pos["e1"].shape # N = n_gals, K = n_samples_per_gal
batch_size = ceil(N / n_jacks)

g_best_list = []
keys = random.split(rng_key, n_jacks)

for ii in tqdm(range(n_jacks), desc="Jackknife #", disable=disable_bar):
k_ii = keys[ii]
start, end = ii * batch_size, (ii + 1) * batch_size

_params_jack_pos = {
k: jnp.concatenate([v[:start], v[end:]]) for k, v in post_params_pos.items()
}
_params_jack_neg = {
k: jnp.concatenate([v[:start], v[end:]]) for k, v in post_params_neg.items()
}

g_pos_ii = shear_pipeline(k_ii, _params_jack_pos, init_g)
g_neg_ii = shear_pipeline(k_ii, _params_jack_neg, -init_g)
g_best_ii = (g_pos_ii - g_neg_ii) * 0.5
g_best_mean_ii = g_best_ii.mean(axis=0)

g_best_list.append(g_best_mean_ii)

g_best_means = jnp.array(g_best_list)
return g_best_means
4 changes: 0 additions & 4 deletions bpd/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,7 @@ def shear_loglikelihood(
interim_logprior: Callable, # fixed
) -> ArrayLike:
"""Shear Likelihood implementation of Schneider et al. 2014."""
e_post = post_params["e1e2"]
_, _, _ = e_post.shape # (N, K, 2)

denom = interim_logprior(post_params)
num = logprior(post_params, g)

ratio = jsp.special.logsumexp(num - denom, axis=-1)
return ratio.sum()
90 changes: 52 additions & 38 deletions bpd/pipelines/image_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,43 @@ def sample_target_galaxy_params_simple(
}


def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]):
true_params = {**galaxy_params}
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
g1, g2 = true_params.pop("g1"), true_params.pop("g2")

e1_prime, e2_prime = scalar_shear_transformation(
jnp.array([e1, e2]), jnp.array([g1, g2])
)
true_params["e1"] = e1_prime
true_params["e2"] = e2_prime

return true_params # don't add back g1,g2 as we are not inferring those in interim posterior


# interim prior
def logprior(
params: dict[str, Array],
*,
sigma_e: float,
sigma_x: float = 0.5, # pixels
flux_bds: tuple = (-1.0, 9.0),
hlr_bds: tuple = (0.01, 5.0),
all_free: bool = True,
hlr_bds: tuple = (-2.0, 1.0),
free_flux_hlr: bool = True,
free_dxdy: bool = True,
) -> Array:
prior = jnp.array(0.0)

if all_free:
if free_flux_hlr:
f1, f2 = flux_bds
prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1)

h1, h2 = hlr_bds
prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1)
prior += stats.uniform.logpdf(params["lhlr"], h1, h2 - h1)

# NOTE: hard-coded assumption that galaxy is in center-pixel within odd-size image.
# sigma_x in units of pixels.
prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x)
if free_dxdy:
prior += stats.norm.logpdf(params["dx"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["dy"], loc=0.0, scale=sigma_x)

e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1)
prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e))
Expand All @@ -76,17 +90,33 @@ def logprior(
def loglikelihood(
params: dict[str, Array],
data: Array,
fixed_params: dict[str, Array],
*,
draw_fnc: Callable,
background: float,
free_flux: bool = True,
free_flux_hlr: bool = True,
free_dxdy: bool = True,
):
# NOTE: draw_fnc should already contain `f` and `hlr` as constant arguments if fixed
_draw_params = {**{"g1": 0.0, "g2": 0.0}, **params}
_draw_params = {}

if free_dxdy:
_draw_params["x"] = params["dx"] + fixed_params["x"]
_draw_params["y"] = params["dy"] + fixed_params["y"]

else:
_draw_params["x"] = fixed_params["x"]
_draw_params["y"] = fixed_params["y"]

if free_flux_hlr:
_draw_params["f"] = 10 ** params["lf"]
_draw_params["hlr"] = 10 ** params["lhlr"]

# Convert log-flux to flux if provided
if free_flux:
_draw_params["f"] = 10 ** _draw_params.pop("lf")
else:
_draw_params["f"] = fixed_params["f"]
_draw_params["hlr"] = fixed_params["hlr"]

_draw_params["e1"] = params["e1"]
_draw_params["e2"] = params["e2"]

model = draw_fnc(**_draw_params)
likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background))
Expand All @@ -97,24 +127,11 @@ def logtarget(
params: dict[str, Array],
data: Array,
*,
fixed_params: dict[str, Array],
logprior_fnc: Callable,
loglikelihood_fnc: Callable,
):
return logprior_fnc(params) + loglikelihood_fnc(params, data)


def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]):
true_params = {**galaxy_params}
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
g1, g2 = true_params.pop("g1"), true_params.pop("g2")

e1_prime, e2_prime = scalar_shear_transformation(
jnp.array([e1, e2]), jnp.array([g1, g2])
)
true_params["e1"] = e1_prime
true_params["e2"] = e2_prime

return true_params # don't add g1,g2 back as we are not inferring those in interim posterior
return logprior_fnc(params) + loglikelihood_fnc(params, data, fixed_params)


def get_target_images_single(
Expand Down Expand Up @@ -156,30 +173,27 @@ def pipeline_interim_samples_one_galaxy(
rng_key: PRNGKeyArray,
true_params: dict[str, float],
target_image: Array,
fixed_draw_kwargs: dict,
fixed_params: dict[str, float],
*,
initialization_fnc: Callable,
draw_fnc: Callable,
logprior: Callable,
loglikelihood: Callable,
n_samples: int = 100,
max_num_doublings: int = 5,
initial_step_size: float = 1e-3,
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = True,
background: float = 1.0,
free_flux: bool = True,
):
# Flux and HLR are fixed to truth and not inferred in this function.
k1, k2 = random.split(rng_key)

init_position = initialization_fnc(k1, true_params=true_params, data=target_image)
_draw_fnc = partial(draw_fnc, **fixed_draw_kwargs)
_loglikelihood = partial(
loglikelihood, draw_fnc=_draw_fnc, background=background, free_flux=free_flux
)

_logtarget = partial(
logtarget, logprior_fnc=logprior, loglikelihood_fnc=_loglikelihood
logtarget,
logprior_fnc=logprior,
loglikelihood_fnc=loglikelihood,
fixed_params=fixed_params,
)

_inference_fnc = partial(
Expand Down
59 changes: 54 additions & 5 deletions experiments/exp2/run_inference_galaxy_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import time
from functools import partial
from typing import Callable

import jax.numpy as jnp
import typer
from jax import jit, random, vmap
from jax import Array, jit, random, vmap
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

from bpd import DATA_DIR
from bpd.chains import run_sampling_nuts, run_warmup_nuts
Expand All @@ -16,11 +18,58 @@
from bpd.pipelines.image_samples import (
get_target_images,
get_true_params_from_galaxy_params,
loglikelihood,
logprior,
logtarget,
sample_target_galaxy_params_simple,
)
from bpd.prior import ellip_prior_e1e2


def logprior(
params: dict[str, Array],
*,
sigma_e: float,
sigma_x: float = 0.5, # pixels
flux_bds: tuple = (-1.0, 9.0),
hlr_bds: tuple = (0.01, 5.0),
) -> Array:
prior = jnp.array(0.0)

f1, f2 = flux_bds
prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1)

h1, h2 = hlr_bds
prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1)

prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x)

e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1)
prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e))

return prior


def loglikelihood(
params: dict[str, Array],
data: Array,
*,
draw_fnc: Callable,
background: float,
):
_draw_params = {**params}
_draw_params["f"] = 10 ** _draw_params.pop("lf")
model = draw_fnc(**_draw_params)
likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background))
return jnp.sum(likelihood_pp)


def logtarget(
params: dict[str, Array],
data: Array,
*,
logprior_fnc: Callable,
loglikelihood_fnc: Callable,
):
return logprior_fnc(params) + loglikelihood_fnc(params, data)


def sample_prior(
Expand Down Expand Up @@ -67,7 +116,7 @@ def main(
pkey, nkey, ikey, rkey = random.split(rng_key, 4)

# directory structure
dirpath = DATA_DIR / "cache_chains" / f"test_image_sampling_{seed}"
dirpath = DATA_DIR / "cache_chains" / f"exp2_{seed}"
if not dirpath.exists():
dirpath.mkdir(exist_ok=True)
fpath = dirpath / f"chain_results_{seed}.npy"
Expand Down
Binary file removed experiments/exp3/figs/contours.pdf
Binary file not shown.
Binary file removed experiments/exp3/figs/scatter_shapes.pdf
Binary file not shown.
Binary file removed experiments/exp3/figs/traces.pdf
Binary file not shown.
2 changes: 1 addition & 1 deletion experiments/exp3/README.md → experiments/exp30/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Experiment 3
# Experiment 3.0

This folder contains scripts to reproduce results on fitting galaxy images in the low noise setting
and inferring shear from them. Details:
Expand Down
Binary file added experiments/exp30/figs/contours.pdf
Binary file not shown.
Binary file added experiments/exp30/figs/hists.pdf
Binary file not shown.
Binary file added experiments/exp30/figs/scatter_shapes.pdf
Binary file not shown.
Binary file added experiments/exp30/figs/traces.pdf
Binary file not shown.
File renamed without changes.
Loading

0 comments on commit 72bc24e

Please sign in to comment.