Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draft of various fixes and improvements (e.g. GB prior) #60

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions bpd/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,13 @@
# the priors are callables for now on only ellipticities
# the interim_prior should have been used when obtaining e_obs from the chain (i.e. for now same sigma)
_, _, _ = e_post.shape # (N, K, 2)
_prior = partial(prior, sigma=sigma_e)

Check failure on line 31 in bpd/likelihood.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (RUF052)

bpd/likelihood.py:31:5: RUF052 Local dummy variable `_prior` is accessed

Check failure on line 31 in bpd/likelihood.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (RUF052)

bpd/likelihood.py:31:5: RUF052 Local dummy variable `_prior` is accessed

Check failure on line 31 in bpd/likelihood.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (RUF052)

bpd/likelihood.py:31:5: RUF052 Local dummy variable `_prior` is accessed

# denom
e_post_mag = norm(e_post, axis=-1)
denom = interim_prior(e_post_mag) # (N, K), can ignore angle in prior as uniform

# for num, use trick
# p(w_n' | g, alpha ) = p(w_n' \cross^{-1} g | alpha ) = p(w_n | alpha) * |jac(w_n / w_n')|

# for num, need to include Jacobian of shear transoformation
# shape = (N, K, 2)
grad1 = _grad_fnc1(e_post, g)
grad2 = _grad_fnc2(e_post, g)
Expand Down
35 changes: 20 additions & 15 deletions bpd/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,40 @@
from jax import Array, random, vmap
from jax._src.prng import PRNGKeyArray
from jax.numpy.linalg import norm
from jax.scipy.special import erf
from jax.typing import ArrayLike


def ellip_mag_prior(e: ArrayLike, sigma: float) -> ArrayLike:
"""Unnormalized Prior for the magnitude of the ellipticity, domain is (0, 1)
def ellip_mag_prior(e_mag: ArrayLike, sigma: float) -> ArrayLike:
"""Prior for the magnitude of the ellipticity, domain is (0, 1)

This distribution is taken from Gary's 2013 paper on Bayesian shear inference.
The additional factor on the truncated Gaussian guarantees differentiability
at e = 0 and e = 1.

Gary uses 0.3 as a default level of shape noise.
Importantly, the paper did not include an additional factor of |e| that is needed
make this the correct expression for a Gaussian in `polar` coordinates. We include
this factor in this equation. This blog post is helpful: https://andrewcharlesjones.github.io/journal/rayleigh.html

The additional factor of (1-e^2)^2 introduced by Gary guarantees differentiability
at e = 0 and e = 1.
"""

# norm from Mathematica
_norm1 = jnp.sqrt(jnp.pi / 2) * (3 * sigma**5 - 2 * sigma**3 + sigma)
_norm1 *= erf(1 / (jnp.sqrt(2) * sigma))
_norm2 = jnp.exp(-1 / (2 * sigma**2)) * (sigma**2 - 3 * sigma**4)
_norm = _norm1 + _norm2
return (1 - e**2) ** 2 * jnp.exp(-(e**2) / (2 * sigma**2)) / _norm
_norm = -4 * sigma**4 + sigma**2 + 8 * sigma**6 * (1 - jnp.exp(-1 / (2 * sigma**2)))
return (1 - e_mag**2) ** 2 * e_mag * jnp.exp(-(e_mag**2) / (2 * sigma**2)) / _norm


def sample_mag_ellip_prior(
rng_key: PRNGKeyArray, sigma: float, n: int = 1, n_bins: int = 1_000_000
):
"""Sample n points from Gary's ellipticity magnitude prior."""
# this part could be cached
e_array = jnp.linspace(0, 1, n_bins)
p_array = ellip_mag_prior(e_array, sigma=sigma)
emag_array = jnp.linspace(0, 1, n_bins)
p_array = ellip_mag_prior(emag_array, sigma=sigma)
p_array /= p_array.sum()
return random.choice(rng_key, e_array, shape=(n,), p=p_array)
return random.choice(rng_key, emag_array, shape=(n,), p=p_array)


def sample_ellip_prior(rng_key: PRNGKeyArray, sigma: float, n: int = 1):
"""Sample n ellipticities isotropic components with Gary's prior from magnitude."""
"""Sample n ellipticities isotropic components with Gary's prior for magnitude."""
key1, key2 = random.split(rng_key, 2)
e_mag = sample_mag_ellip_prior(key1, sigma=sigma, n=n)
e_phi = random.uniform(key2, shape=(n,), minval=0, maxval=jnp.pi)
Expand All @@ -45,6 +44,12 @@ def sample_ellip_prior(rng_key: PRNGKeyArray, sigma: float, n: int = 1):
return jnp.stack((e1, e2), axis=1)


def ellip_prior_e1e2(e: Array, sigma: float) -> ArrayLike:
"""Prior on e1, e2 using Gary's prior for magnitude. Includes Jacobian factor."""
e_mag = jnp.sqrt(e[..., 0] ** 2 + e[..., 1] ** 2)
return ellip_mag_prior(e_mag, sigma=sigma) / e_mag # jacobian factor


def scalar_shear_transformation(e: Array, g: Array):
"""Transform elliptiticies by a fixed shear (scalar version).

Expand Down
Binary file modified experiments/exp2/figs/contours.pdf
Binary file not shown.
Binary file modified experiments/exp2/figs/convergence_hist.pdf
Binary file not shown.
8 changes: 4 additions & 4 deletions experiments/exp2/figs/outliers.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Number of R-hat outliers for e1: 1
Number of R-hat outliers for e2: 2
Number of R-hat outliers for e1: 0
Number of R-hat outliers for e2: 0
Number of R-hat outliers for hlr: 0
Number of R-hat outliers for lf: 0
Number of R-hat outliers for x: 2
Number of R-hat outliers for y: 3
Number of R-hat outliers for x: 0
Number of R-hat outliers for y: 0
Binary file modified experiments/exp2/figs/timing.pdf
Binary file not shown.
Binary file modified experiments/exp2/figs/traces.pdf
Binary file not shown.
Binary file modified experiments/exp2/figs/traces_adapt.pdf
Binary file not shown.
Binary file removed experiments/exp2/figs/traces_out.pdf
Binary file not shown.
Binary file added experiments/exp2/figs/tuned_hists.pdf
Binary file not shown.
39 changes: 37 additions & 2 deletions experiments/exp2/make_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,19 @@ def make_convergence_histograms(samples_dict: dict[str, Array]) -> None:

with PdfPages(fname) as pdf:
for p in samples_dict:
rhat_p = rhats[p]
ess_p = ess[p]
rhat_p = np.array(rhats[p])
ess_p = np.array(ess[p])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 6))
fig.suptitle(p, fontsize=18)

ax1.hist(rhat_p, bins=25, range=(0.98, 1.1))
ax2.hist(ess_p, bins=25)
ax2.axvline(ess_p.mean(), linestyle="--", color="k", label="mean")

ax1.set_xlabel("R-hat")
ax2.set_ylabel("ESS")
ax2.legend()

pdf.savefig(fig)
plt.close(fig)
Expand Down Expand Up @@ -176,6 +178,34 @@ def make_timing_plots(results_dict: dict) -> None:
plt.close(fig)


def make_adaptation_hists(tuned_params: dict, pnames: dict):
fname = "figs/tuned_hists.pdf"

step_sizes = tuned_params["step_size"]
imm = tuned_params["inverse_mass_matrix"]

with PdfPages(fname) as pdf:
fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.hist(step_sizes.flatten(), bins=25)
ax.axvline(step_sizes.flatten().mean(), linestyle="--", color="k", label="mean")
ax.set_xlabel("Step sizes")
ax.legend()

pdf.savefig(fig)
plt.close(fig)

for ii, p in enumerate(pnames):
diag_elems = imm[:, :, ii].flatten()
fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.hist(diag_elems, bins=25)
ax.axvline(diag_elems.mean(), linestyle="--", color="k", label="mean")
ax.set_xlabel(f"Diag Mass Matrix for {p}")
ax.legend()

pdf.savefig(fig)
plt.close(fig)


def main():
np.random.seed(42)

Expand All @@ -187,6 +217,9 @@ def main():
max_n_gal = max(results.keys())
samples = results[max_n_gal]["samples"]
truth = results[max_n_gal]["truth"]
tuned_params = results[max_n_gal]["tuned_params"]

param_names = list(samples.keys())

# make plots
make_trace_plots(samples, truth, fpath="figs/traces.pdf")
Expand All @@ -205,6 +238,8 @@ def main():
if Path("figs/traces_out.pdf").exists():
os.remove("figs/traces_out.pdf")

make_adaptation_hists(tuned_params, param_names)


if __name__ == "__main__":
main()
4 changes: 3 additions & 1 deletion experiments/exp2/run_inference_galaxy_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def main(
if not dirpath.exists():
dirpath.mkdir(exist_ok=True)
fpath = dirpath / f"chain_results_{seed}.npy"
assert not fpath.exists()

# setup target density
draw_fnc = partial(draw_gaussian, slen=slen, fft_size=fft_size)
Expand Down Expand Up @@ -100,7 +101,7 @@ def main(
_run_sampling = vmap(vmap(jjit(_run_sampling1), in_axes=(0, 0, 0, None)))

results = {}
for n_gals in (1, 1, 5, 10, 20, 25, 50, 100, 250, 500): # repeat 1 == compilation
for n_gals in (1, 1, 5, 10, 20, 25, 50, 100, 250, 300): # repeat 1 == compilation
print("n_gals:", n_gals)

# generate data and parameters
Expand Down Expand Up @@ -146,6 +147,7 @@ def main(
results[n_gals]["samples"] = samples
results[n_gals]["truth"] = true_params
results[n_gals]["adapt_info"] = adapt_info
results[n_gals]["tuned_params"] = tuned_params

jnp.save(fpath, results)

Expand Down
1 change: 1 addition & 0 deletions experiments/exp2_optimization/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Experiment 2 (Optimization)
Binary file added experiments/exp2_optimization/figs/contours.pdf
Binary file not shown.
Binary file not shown.
6 changes: 6 additions & 0 deletions experiments/exp2_optimization/figs/outliers.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Number of R-hat outliers for e1: 0
Number of R-hat outliers for e2: 0
Number of R-hat outliers for hlr: 0
Number of R-hat outliers for lf: 0
Number of R-hat outliers for x: 4
Number of R-hat outliers for y: 5
Binary file added experiments/exp2_optimization/figs/timing.pdf
Binary file not shown.
Binary file added experiments/exp2_optimization/figs/traces.pdf
Binary file not shown.
Binary file not shown.
Binary file not shown.
5 changes: 5 additions & 0 deletions experiments/exp2_optimization/get_posteriors.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES="0"
export JAX_ENABLE_X64="True"

./run_inference_galaxy_images.py 42
Loading
Loading