-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6a4b817
commit 606a7c1
Showing
11 changed files
with
446 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,4 +23,7 @@ git clone [email protected]:LSSTDESC/BPD.git | |
cd BPD | ||
pip install -e . | ||
pip install -e ".[dev]" | ||
|
||
# Might be necessary | ||
conda install -c nvidia cuda-nvcc | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Experiment 4 | ||
|
||
Same as experiment 3, but we let the deviations from the true centroid `(dx, dy)` be free parameters in | ||
the fit. | ||
|
||
|
||
For now we assume that the true and interim priors on the deviation are the same, so this prior is | ||
not accounted for in the shear inference. |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
122 changes: 99 additions & 23 deletions
122
experiments/exp4/get_interim_samples.py
100644 → 100755
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,122 @@ | ||
#!/usr/bin/env python3 | ||
"""Check chains ran on a variety of galaxies with different SNR, initialization from the prior.""" | ||
|
||
import time | ||
from functools import partial | ||
|
||
import jax.numpy as jnp | ||
import typer | ||
from jax import jit, random, vmap | ||
from jax._src.prng import PRNGKeyArray | ||
|
||
from bpd import DATA_DIR | ||
from bpd.chains import run_sampling_nuts, run_warmup_nuts | ||
from bpd.draw import draw_gaussian | ||
from bpd.initialization import init_with_prior | ||
from bpd.initialization import init_with_truth | ||
from bpd.io import save_dataset | ||
from bpd.pipelines.image_samples import ( | ||
get_target_images, | ||
get_true_params_from_galaxy_params, | ||
loglikelihood, | ||
logprior, | ||
logtarget, | ||
pipeline_interim_samples_one_galaxy, | ||
sample_target_galaxy_params_simple, | ||
) | ||
|
||
|
||
def sample_prior( | ||
rng_key: PRNGKeyArray, | ||
*, | ||
mean_logflux: float = 6.0, | ||
sigma_logflux: float = 0.1, | ||
mean_loghlr: float = 1.0, | ||
sigma_loghlr: float = 0.05, | ||
shape_noise: float = 0.3, | ||
def main( | ||
seed: int, | ||
n_gals: int = 1000, # technically, in this file it means 'noise realizations' | ||
n_samples_per_gal: int = 100, | ||
g1: float = 0.02, | ||
g2: float = 0.0, | ||
) -> dict[str, float]: | ||
k1, k2, k3 = random.split(rng_key, 3) | ||
lf: float = 6.0, # ~ SNR = 1000 | ||
hlr: float = 1.0, | ||
shape_noise: float = 1e-4, | ||
sigma_e_int: float = 4e-2, | ||
slen: int = 53, | ||
fft_size: int = 256, | ||
background: float = 1.0, | ||
initial_step_size: float = 1e-3, | ||
): | ||
rng_key = random.key(seed) | ||
pkey, nkey, gkey = random.split(rng_key, 3) | ||
|
||
# directory structure | ||
dirpath = DATA_DIR / "cache_chains" / f"exp4_{seed}" | ||
if not dirpath.exists(): | ||
dirpath.mkdir(exist_ok=True) | ||
|
||
# galaxy parameters from prior | ||
pkeys = random.split(pkey, n_gals) | ||
_get_galaxy_params = partial( | ||
sample_target_galaxy_params_simple, g1=g1, g2=g2, shape_noise=shape_noise | ||
) | ||
galaxy_params = vmap(_get_galaxy_params)(pkeys) | ||
assert galaxy_params["x"].shape == (n_gals,) | ||
assert galaxy_params["e1"].shape == (n_gals,) | ||
assert "lf" not in galaxy_params | ||
extra_params = {"f": 10 ** jnp.full((n_gals,), lf), "hlr": jnp.full((n_gals,), hlr)} | ||
|
||
lf = random.normal(k1) * sigma_logflux + mean_logflux | ||
hlr = random.normal(k2) * sigma_loghlr + mean_loghlr | ||
# now get corresponding target images | ||
draw_params = {**galaxy_params, **extra_params} | ||
target_images = get_target_images( | ||
nkey, draw_params, background=background, slen=slen | ||
) | ||
assert target_images.shape == (n_gals, slen, slen) | ||
|
||
# interim samples are on 'sheared ellipticity' | ||
true_params = vmap(get_true_params_from_galaxy_params)(galaxy_params) | ||
|
||
other_params = sample_target_galaxy_params_simple( | ||
k3, shape_noise=shape_noise, g1=g1, g2=g2 | ||
# we pass in x,y as fixed parameters for drawing | ||
# and initialize the function with deviations (dx, dy) = (0, 0) | ||
x = true_params.pop("x") | ||
y = true_params.pop("y") | ||
true_params["dx"] = jnp.zeros_like(x) | ||
true_params["dy"] = jnp.zeros_like(y) | ||
extra_params = {"x": x, "y": y, **extra_params} | ||
|
||
# more setup | ||
_logprior = partial( | ||
logprior, sigma_e=sigma_e_int, free_flux_hlr=False, free_dxdy=True | ||
) | ||
|
||
return {"lf": lf, "hlr": hlr, **other_params} | ||
# prepare pipelines | ||
gkeys = random.split(gkey, n_gals) | ||
_draw_fnc = partial(draw_gaussian, slen=slen, fft_size=fft_size) | ||
pipe = partial( | ||
pipeline_interim_samples_one_galaxy, | ||
initialization_fnc=init_with_truth, | ||
draw_fnc=_draw_fnc, | ||
logprior=_logprior, | ||
n_samples=n_samples_per_gal, | ||
initial_step_size=initial_step_size, | ||
background=background, | ||
free_flux=False, | ||
) | ||
vpipe = vmap(jit(pipe)) | ||
|
||
# compilation on single target image | ||
_ = vpipe( | ||
gkeys[0, None], | ||
{k: v[0, None] for k, v in true_params.items()}, | ||
target_images[0, None], | ||
{k: v[0, None] for k, v in extra_params.items()}, | ||
) | ||
|
||
samples = vpipe(gkeys, true_params, target_images, extra_params) | ||
e_post = jnp.stack([samples["e1"], samples["e2"]], axis=-1) | ||
fpath = dirpath / f"e_post_{seed}.npz" | ||
|
||
save_dataset( | ||
{ | ||
"e_post": e_post, | ||
"true_g": jnp.array([g1, g2]), | ||
"dx": samples["dx"], | ||
"dy": samples["dy"], | ||
"sigma_e": shape_noise, | ||
"sigma_e_int": sigma_e_int, | ||
"e1": draw_params["e1"], | ||
"e2": draw_params["e2"], | ||
}, | ||
fpath, | ||
overwrite=True, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
typer.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/bin/bash
# Run experiment 4 end to end: first draw the interim samples, then run the
# shear-inference script on the file the first step wrote.
#
# Abort on the first failing command, unset variable, or failed pipeline stage,
# so the shear step never runs on missing or stale interim samples.
set -euo pipefail

export CUDA_VISIBLE_DEVICES="0"
export JAX_ENABLE_X64="True"
SEED="43"

# Both paths are relative, so this script must be launched from its own directory.
./get_interim_samples.py "$SEED"
../../scripts/get_shear_from_interim_samples.py "$SEED" "exp4_${SEED}" "e_post_${SEED}.npz" --overwrite
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import os | ||
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = "" | ||
os.environ["JAX_PLATFORMS"] = "cpu" | ||
os.environ["JAX_ENABLE_X64"] = "True" | ||
|
||
|
||
import jax.numpy as jnp | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import typer | ||
from jax import Array | ||
from matplotlib.backends.backend_pdf import PdfPages | ||
|
||
from bpd import DATA_DIR | ||
from bpd.diagnostics import get_contour_plot | ||
from bpd.io import load_dataset | ||
|
||
|
||
def make_trace_plots(g_samples: Array) -> None:
    """Write the g1 and g2 sample traces to `figs/traces.pdf`.

    Column 0 of `g_samples` is plotted on the upper panel and column 1 on the
    lower panel of a single page.
    """
    with PdfPages("figs/traces.pdf") as pdf:
        fig, panels = plt.subplots(2, 1, figsize=(10, 5))

        # panel i shows component i of the samples (0 -> g1, 1 -> g2)
        for component, panel in enumerate(panels):
            panel.plot(g_samples[:, component])

        pdf.savefig(fig)
        plt.close(fig)
|
||
|
||
def make_contour_plots(
    g_samples: Array,
    n_examples: int = 10,
    true_g1: float = 0.02,
    true_g2: float = 0.0,
) -> None:
    """Write a contour plot of the (g1, g2) samples to `figs/contours.pdf`.

    Args:
        g_samples: Samples whose column 0 is g1 and column 1 is g2.
        n_examples: Unused; kept so existing callers keep working.
        true_g1: Reference value for g1 handed to `get_contour_plot`.
        true_g2: Reference value for g2 handed to `get_contour_plot`.

    The defaults reproduce the previously hard-coded truth (0.02, 0.0); pass the
    shear actually used to generate the data when it differs.
    """
    truth = {"g1": true_g1, "g2": true_g2}
    g_dict = {"g1": g_samples[:, 0], "g2": g_samples[:, 1]}

    with PdfPages("figs/contours.pdf") as pdf:
        # single sample set labelled "post"; `truth` is forwarded as the third
        # positional argument (presumably drawn as the reference point -- confirm
        # against bpd.diagnostics.get_contour_plot)
        fig = get_contour_plot([g_dict], ["post"], truth)
        pdf.savefig(fig)
        plt.close(fig)
|
||
|
||
def make_scatter_shape_plots(e_post: Array, n_examples: int = 10) -> None:
    """Write scatter plots of ellipticity samples to `figs/scatter_shapes.pdf`.

    `e_post` must be 3-dimensional: the first axis is indexed per galaxy and the
    last axis holds e1 at index 0 and e2 at index 1. The PDF contains
    `n_examples` single-galaxy pages followed by one page overlaying 50 randomly
    chosen galaxies (drawn with replacement via `np.random.choice`).
    """
    (n_gals, _n_samples, _n_components) = e_post.shape
    galaxy_ids = np.arange(0, n_gals)
    label_size = 14

    with PdfPages("figs/scatter_shapes.pdf") as pdf:
        # one page per randomly picked galaxy
        for _ in range(n_examples):
            idx = np.random.choice(galaxy_ids)
            fig, ax = plt.subplots(1, 1, figsize=(7, 7))
            ax.scatter(e_post[idx, :, 0], e_post[idx, :, 1], marker="x")
            ax.set_title(f"Samples ellipticity index: {idx}")
            ax.set_xlabel("e1", fontsize=label_size)
            ax.set_ylabel("e2", fontsize=label_size)
            pdf.savefig(fig)
            plt.close(fig)

        # final page: many galaxies overlaid on the same axes
        n_clusters = 50
        fig, ax = plt.subplots(1, 1, figsize=(7, 7))
        ax.set_xlabel("e1", fontsize=label_size)
        ax.set_ylabel("e2", fontsize=label_size)
        fig.suptitle(f"{n_clusters} galaxies plotted")
        for _ in range(n_clusters):
            idx = np.random.choice(galaxy_ids)
            ax.scatter(e_post[idx, :, 0], e_post[idx, :, 1], marker="x")
        pdf.savefig(fig)
        plt.close(fig)
|
||
|
||
def make_scatter_dxdy_plots(dx: Array, dy: Array, n_examples: int = 10) -> None:
    """Write scatter plots of the (dx, dy) samples to `figs/scatter_dxdy.pdf`.

    Args:
        dx: 2-dimensional array of dx samples; the first axis is indexed per galaxy.
        dy: Array of dy samples, indexed the same way as `dx`.
        n_examples: Number of single-galaxy pages to produce.

    The PDF contains `n_examples` single-galaxy pages followed by one page
    overlaying 50 randomly chosen galaxies (drawn with replacement).
    """
    fname = "figs/scatter_dxdy.pdf"

    n_gals, _ = dx.shape

    with PdfPages(fname) as pdf:
        # individual
        for _ in range(n_examples):
            idx = np.random.choice(np.arange(0, n_gals))
            dx1, dy1 = dx[idx, :], dy[idx, :]
            fig, ax = plt.subplots(1, 1, figsize=(7, 7))
            ax.scatter(dx1, dy1, marker="x")
            # title previously said "ellipticity" (copy-paste from the shape plots)
            ax.set_title(f"Samples centroid deviation (dx, dy) index: {idx}")
            ax.set_xlabel("dx", fontsize=14)
            ax.set_ylabel("dy", fontsize=14)
            pdf.savefig(fig)
            plt.close(fig)

        # clusters
        n_clusters = 50
        fig, ax = plt.subplots(1, 1, figsize=(7, 7))
        ax.set_xlabel("dx", fontsize=14)
        ax.set_ylabel("dy", fontsize=14)
        fig.suptitle(f"{n_clusters} galaxies plotted")
        for _ in range(n_clusters):
            idx = np.random.choice(np.arange(0, n_gals))
            dx1, dy1 = dx[idx, :], dy[idx, :]
            ax.scatter(dx1, dy1, marker="x")
        pdf.savefig(fig)
        plt.close(fig)
|
||
|
||
def make_hists(g_samples: Array, e1_samples: Array) -> None:
    """Write a histogram of the g1 samples to `figs/hists.pdf`.

    The title reports the standard deviation of the g1 samples next to an
    expected value computed as `e1_samples.std() / sqrt(len(e1_samples))`.
    A dashed vertical line marks the mean of the g1 samples.
    """
    with PdfPages("figs/hists.pdf") as pdf:
        fig, ax = plt.subplots(1, 1, figsize=(7, 7))

        g1 = g_samples[:, 0]
        n_e1 = len(e1_samples)
        g1_exp_std = e1_samples.std() / jnp.sqrt(n_e1)
        title = f"Std g1: {g1.std():.4g}; Expected g1 std: {g1_exp_std:.4g}"

        ax.hist(g1, bins=25, histtype="step")
        ax.axvline(g1.mean(), linestyle="--", color="k")
        ax.set_title(title)

        pdf.savefig(fig)
        plt.close(fig)
|
||
|
||
def main(seed: int = 43):
    # Inputs live under DATA_DIR/cache_chains/exp4_<seed>: the interim-sample
    # dataset (e_post_<seed>.npz) and the shear samples (g_samples_<seed>_<seed>.npy).
    chain_dir = DATA_DIR / "cache_chains" / f"exp4_{seed}"
    interim = load_dataset(chain_dir / f"e_post_{seed}.npz")
    e_post_samples = interim["e_post"]
    g_samples = jnp.load(chain_dir / f"g_samples_{seed}_{seed}.npy")

    # pull out everything the plots need before any figure is produced
    e1_samples, dx, dy = interim["e1"], interim["dx"], interim["dy"]

    # each helper writes its own PDF under figs/
    make_scatter_shape_plots(e_post_samples)
    make_scatter_dxdy_plots(dx, dy)
    make_trace_plots(g_samples)
    make_contour_plots(g_samples)
    make_hists(g_samples, e1_samples)
|
||
|
||
if __name__ == "__main__": | ||
typer.run(main) |
Oops, something went wrong.