Skip to content

Commit

Permalink
add exp with centroid free
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Dec 31, 2024
1 parent 6a4b817 commit 606a7c1
Show file tree
Hide file tree
Showing 11 changed files with 446 additions and 23 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ git clone [email protected]:LSSTDESC/BPD.git
cd BPD
pip install -e .
pip install -e ".[dev]"

# Might be necessary
conda install -c nvidia cuda-nvcc
```
8 changes: 8 additions & 0 deletions experiments/exp4/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Experiment 4

Same as experiment 3, but the deviations `(dx, dy)` from the true centroid are free parameters in
the fit.


For now we assume that the true and interim priors on the deviation are the same, so the deviation
prior is not accounted for in the shear inference.
Binary file added experiments/exp4/figs/contours.pdf
Binary file not shown.
Binary file added experiments/exp4/figs/hists.pdf
Binary file not shown.
Binary file added experiments/exp4/figs/scatter_dxdy.pdf
Binary file not shown.
Binary file added experiments/exp4/figs/scatter_shapes.pdf
Binary file not shown.
Binary file added experiments/exp4/figs/traces.pdf
Binary file not shown.
122 changes: 99 additions & 23 deletions experiments/exp4/get_interim_samples.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,46 +1,122 @@
#!/usr/bin/env python3
"""Check chains ran on a variety of galaxies with different SNR, initialization from the prior."""

import time
from functools import partial

import jax.numpy as jnp
import typer
from jax import jit, random, vmap
from jax._src.prng import PRNGKeyArray

from bpd import DATA_DIR
from bpd.chains import run_sampling_nuts, run_warmup_nuts
from bpd.draw import draw_gaussian
from bpd.initialization import init_with_prior
from bpd.initialization import init_with_truth
from bpd.io import save_dataset
from bpd.pipelines.image_samples import (
get_target_images,
get_true_params_from_galaxy_params,
loglikelihood,
logprior,
logtarget,
pipeline_interim_samples_one_galaxy,
sample_target_galaxy_params_simple,
)


def sample_prior(
rng_key: PRNGKeyArray,
*,
mean_logflux: float = 6.0,
sigma_logflux: float = 0.1,
mean_loghlr: float = 1.0,
sigma_loghlr: float = 0.05,
shape_noise: float = 0.3,
def main(
seed: int,
n_gals: int = 1000, # technically, in this file it means 'noise realizations'
n_samples_per_gal: int = 100,
g1: float = 0.02,
g2: float = 0.0,
) -> dict[str, float]:
k1, k2, k3 = random.split(rng_key, 3)
lf: float = 6.0, # ~ SNR = 1000
hlr: float = 1.0,
shape_noise: float = 1e-4,
sigma_e_int: float = 4e-2,
slen: int = 53,
fft_size: int = 256,
background: float = 1.0,
initial_step_size: float = 1e-3,
):
rng_key = random.key(seed)
pkey, nkey, gkey = random.split(rng_key, 3)

# directory structure
dirpath = DATA_DIR / "cache_chains" / f"exp4_{seed}"
if not dirpath.exists():
dirpath.mkdir(exist_ok=True)

# galaxy parameters from prior
pkeys = random.split(pkey, n_gals)
_get_galaxy_params = partial(
sample_target_galaxy_params_simple, g1=g1, g2=g2, shape_noise=shape_noise
)
galaxy_params = vmap(_get_galaxy_params)(pkeys)
assert galaxy_params["x"].shape == (n_gals,)
assert galaxy_params["e1"].shape == (n_gals,)
assert "lf" not in galaxy_params
extra_params = {"f": 10 ** jnp.full((n_gals,), lf), "hlr": jnp.full((n_gals,), hlr)}

lf = random.normal(k1) * sigma_logflux + mean_logflux
hlr = random.normal(k2) * sigma_loghlr + mean_loghlr
# now get corresponding target images
draw_params = {**galaxy_params, **extra_params}
target_images = get_target_images(
nkey, draw_params, background=background, slen=slen
)
assert target_images.shape == (n_gals, slen, slen)

# interim samples are on 'sheared ellipticity'
true_params = vmap(get_true_params_from_galaxy_params)(galaxy_params)

other_params = sample_target_galaxy_params_simple(
k3, shape_noise=shape_noise, g1=g1, g2=g2
# we pass in x,y as fixed parameters for drawing
# and initialize the function with deviations (dx, dy) = (0, 0)
x = true_params.pop("x")
y = true_params.pop("y")
true_params["dx"] = jnp.zeros_like(x)
true_params["dy"] = jnp.zeros_like(y)
extra_params = {"x": x, "y": y, **extra_params}

# more setup
_logprior = partial(
logprior, sigma_e=sigma_e_int, free_flux_hlr=False, free_dxdy=True
)

return {"lf": lf, "hlr": hlr, **other_params}
# prepare pipelines
gkeys = random.split(gkey, n_gals)
_draw_fnc = partial(draw_gaussian, slen=slen, fft_size=fft_size)
pipe = partial(
pipeline_interim_samples_one_galaxy,
initialization_fnc=init_with_truth,
draw_fnc=_draw_fnc,
logprior=_logprior,
n_samples=n_samples_per_gal,
initial_step_size=initial_step_size,
background=background,
free_flux=False,
)
vpipe = vmap(jit(pipe))

# compilation on single target image
_ = vpipe(
gkeys[0, None],
{k: v[0, None] for k, v in true_params.items()},
target_images[0, None],
{k: v[0, None] for k, v in extra_params.items()},
)

samples = vpipe(gkeys, true_params, target_images, extra_params)
e_post = jnp.stack([samples["e1"], samples["e2"]], axis=-1)
fpath = dirpath / f"e_post_{seed}.npz"

save_dataset(
{
"e_post": e_post,
"true_g": jnp.array([g1, g2]),
"dx": samples["dx"],
"dy": samples["dy"],
"sigma_e": shape_noise,
"sigma_e_int": sigma_e_int,
"e1": draw_params["e1"],
"e2": draw_params["e2"],
},
fpath,
overwrite=True,
)


# Script entry point: typer builds the command-line interface from `main`'s signature.
if __name__ == "__main__":
    typer.run(main)
7 changes: 7 additions & 0 deletions experiments/exp4/get_posteriors.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash
# Run experiment 4 end to end: draw interim posterior samples, then infer shear from them.
# Paths are relative, so this script must be launched from experiments/exp4.

# Abort on the first failing command, unset variable, or failed pipe stage, so the
# shear step never runs on missing or stale interim samples.
set -euo pipefail

export CUDA_VISIBLE_DEVICES="0"
export JAX_ENABLE_X64="True"
SEED="43"

# Step 1: writes e_post_${SEED}.npz under the exp4_${SEED} cache directory.
./get_interim_samples.py "$SEED"

# Step 2: consumes the file written in step 1.
../../scripts/get_shear_from_interim_samples.py "$SEED" "exp4_${SEED}" "e_post_${SEED}.npz" --overwrite
154 changes: 154 additions & 0 deletions experiments/exp4/make_figures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#!/usr/bin/env python3

import os

# Plotting does not need an accelerator: hide CUDA devices, ask JAX for the CPU
# platform, and request 64-bit mode. These are set here, ahead of the `jax` import
# further down, presumably so JAX picks them up when it is first imported.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_ENABLE_X64"] = "True"


import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import typer
from jax import Array
from matplotlib.backends.backend_pdf import PdfPages

from bpd import DATA_DIR
from bpd.diagnostics import get_contour_plot
from bpd.io import load_dataset


def make_trace_plots(g_samples: Array) -> None:
    """Write a two-panel trace figure of the shear samples to ``figs/traces.pdf``.

    The top panel shows column 0 of ``g_samples`` (g1) and the bottom panel shows
    column 1 (g2), each plotted against its row index.
    """
    with PdfPages("figs/traces.pdf") as pdf:
        fig, axes = plt.subplots(2, 1, figsize=(10, 5))

        # One panel per shear component: axes[0] <- g1, axes[1] <- g2.
        for component, ax in enumerate(axes):
            ax.plot(g_samples[:, component])

        pdf.savefig(fig)
        plt.close(fig)


def make_contour_plots(
    g_samples: Array,
    n_examples: int = 10,
    *,
    true_g1: float = 0.02,
    true_g2: float = 0.0,
) -> None:
    """Write a contour plot of the (g1, g2) samples to ``figs/contours.pdf``.

    Args:
        g_samples: 2D array whose column 0 holds g1 samples and column 1 holds g2.
        n_examples: Unused; kept so existing callers keep working.
        true_g1: True g1 value handed to the contour plot as the reference.
        true_g2: True g2 value handed to the contour plot as the reference.

    The defaults for ``true_g1``/``true_g2`` reproduce the values that used to be
    hard-coded here; pass different values when the samples were generated with a
    different input shear.
    """
    fname = "figs/contours.pdf"
    with PdfPages(fname) as pdf:
        truth = {"g1": true_g1, "g2": true_g2}
        g_dict = {"g1": g_samples[:, 0], "g2": g_samples[:, 1]}
        # A single sample set labelled "post", with the truth values as reference.
        fig = get_contour_plot([g_dict], ["post"], truth)
        pdf.savefig(fig)
        plt.close(fig)


def make_scatter_shape_plots(e_post: Array, n_examples: int = 10) -> None:
    """Write scatter plots of interim posterior ellipticities to ``figs/scatter_shapes.pdf``.

    ``e_post`` must be 3D and is indexed as ``[galaxy, sample, component]`` with
    component 0 = e1 and component 1 = e2. The PDF holds ``n_examples`` pages, each
    showing one randomly chosen galaxy, followed by a single page that overlays the
    samples of 50 random picks.
    """
    n_gals, _, _ = e_post.shape
    gal_indices = np.arange(0, n_gals)

    with PdfPages("figs/scatter_shapes.pdf") as pdf:
        # One page per randomly picked galaxy (picks may repeat).
        for _ in range(n_examples):
            idx = np.random.choice(gal_indices)
            fig, ax = plt.subplots(1, 1, figsize=(7, 7))
            ax.scatter(e_post[idx, :, 0], e_post[idx, :, 1], marker="x")
            ax.set_title(f"Samples ellipticity index: {idx}")
            ax.set_xlabel("e1", fontsize=14)
            ax.set_ylabel("e2", fontsize=14)
            pdf.savefig(fig)
            plt.close(fig)

        # Final page: many galaxies overlaid on the same axes.
        n_clusters = 50
        fig, ax = plt.subplots(1, 1, figsize=(7, 7))
        ax.set_xlabel("e1", fontsize=14)
        ax.set_ylabel("e2", fontsize=14)
        fig.suptitle(f"{n_clusters} galaxies plotted")
        for _ in range(n_clusters):
            idx = np.random.choice(gal_indices)
            ax.scatter(e_post[idx, :, 0], e_post[idx, :, 1], marker="x")
        pdf.savefig(fig)
        plt.close(fig)


def make_scatter_dxdy_plots(dx: Array, dy: Array, n_examples: int = 10) -> None:
    """Show example scatter plots of interim posterior centroid offsets (dx, dy).

    Writes ``figs/scatter_dxdy.pdf``: ``n_examples`` pages that each show one
    randomly chosen galaxy, followed by one page overlaying 50 random picks.

    Args:
        dx: 2D array indexed as ``[galaxy, sample]`` holding the dx samples.
        dy: Array indexed the same way as ``dx`` holding the dy samples.
        n_examples: Number of single-galaxy pages to produce.
    """
    fname = "figs/scatter_dxdy.pdf"

    n_gals, _ = dx.shape

    with PdfPages(fname) as pdf:
        # individual galaxies (picked at random, so the same one may appear twice)
        for _ in range(n_examples):
            idx = np.random.choice(np.arange(0, n_gals))
            dx1, dy1 = dx[idx, :], dy[idx, :]
            fig, ax = plt.subplots(1, 1, figsize=(7, 7))
            ax.scatter(dx1, dy1, marker="x")
            # These are centroid-offset samples, not ellipticities.
            ax.set_title(f"Samples (dx, dy) index: {idx}")
            ax.set_xlabel("dx", fontsize=14)
            ax.set_ylabel("dy", fontsize=14)
            pdf.savefig(fig)
            plt.close(fig)

        # clusters: overlay the samples of many galaxies on a single page
        n_clusters = 50
        fig, ax = plt.subplots(1, 1, figsize=(7, 7))
        ax.set_xlabel("dx", fontsize=14)
        ax.set_ylabel("dy", fontsize=14)
        fig.suptitle(f"{n_clusters} galaxies plotted")
        for _ in range(n_clusters):
            idx = np.random.choice(np.arange(0, n_gals))
            dx1, dy1 = dx[idx, :], dy[idx, :]
            ax.scatter(dx1, dy1, marker="x")
        pdf.savefig(fig)
        plt.close(fig)


def make_hists(g_samples: Array, e1_samples: Array) -> None:
    """Write a histogram of the g1 samples to ``figs/hists.pdf``.

    The title reports the sample standard deviation of g1 next to an expected
    value computed as ``std(e1_samples) / sqrt(len(e1_samples))``. A dashed
    vertical line marks the mean of the g1 samples.
    """
    with PdfPages("figs/hists.pdf") as pdf:
        fig, ax = plt.subplots(1, 1, figsize=(7, 7))

        g1 = g_samples[:, 0]
        # Spread of e1 divided by sqrt(number of e1 entries).
        n_e1 = len(e1_samples)
        g1_exp_std = e1_samples.std() / jnp.sqrt(n_e1)
        title = f"Std g1: {g1.std():.4g}; Expected g1 std: {g1_exp_std:.4g}"

        ax.hist(g1, bins=25, histtype="step")
        ax.axvline(g1.mean(), linestyle="--", color="k")
        ax.set_title(title)

        pdf.savefig(fig)
        plt.close(fig)


def main(seed: int = 43):
    """Load the cached exp4 outputs for ``seed`` and regenerate every figure."""
    # Everything for this run lives under <DATA_DIR>/cache_chains/exp4_<seed>.
    cache_dir = DATA_DIR / "cache_chains" / f"exp4_{seed}"
    dataset = load_dataset(cache_dir / f"e_post_{seed}.npz")
    interim_e = dataset["e_post"]
    shear_samples = jnp.load(cache_dir / f"g_samples_{seed}_{seed}.npy")

    e1_values = dataset["e1"]
    dx_samples = dataset["dx"]
    dy_samples = dataset["dy"]

    # Figures are produced in a fixed order: interim posteriors first, then shear.
    make_scatter_shape_plots(interim_e)
    make_scatter_dxdy_plots(dx_samples, dy_samples)
    make_trace_plots(shear_samples)
    make_contour_plots(shear_samples)
    make_hists(shear_samples, e1_values)


# Script entry point: typer exposes `main`'s `seed` parameter on the command line.
if __name__ == "__main__":
    typer.run(main)
Loading

0 comments on commit 606a7c1

Please sign in to comment.