-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Draft implementation of shape noise cancellation with jackknife proce…
…dure (#72) * move notebooks around * draft * initial draft seems correct * attempt to generalize jackknife procedure
- Loading branch information
1 parent
bf46bb9
commit d5aaa0f
Showing
6 changed files
with
642 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from math import ceil | ||
from typing import Callable | ||
|
||
import jax.numpy as jnp | ||
from jax import Array, random | ||
from tqdm import tqdm | ||
|
||
|
||
def run_jackknife_shear_pipeline( | ||
rng_key, | ||
init_g: Array, | ||
post_params_pos: dict, | ||
post_params_neg: dict, | ||
shear_pipeline: Callable, | ||
n_jacks: int = 10, | ||
disable_bar: bool = True, | ||
): | ||
"""Use Jackknife to estimate the mean and std of the shear posterior. | ||
Args: | ||
rng_key: Random jax key. | ||
init_g: Initial value for shear `g`. | ||
post_params_pos: Interim posterior galaxy parameters estimated using positive shear. | ||
post_params_neg: Interim posterior galaxy parameters estimated using negative shear, | ||
and otherwise same conditions and random seed as `post_params_pos`. | ||
shear_pipeline: Function that outputs shear posterior samples from `post_params` with all | ||
keyword arguments pre-specified. | ||
n_jacks: Number of jackknife batches. | ||
Returns: | ||
Jackknife | ||
""" | ||
N, _ = post_params_pos["e1"].shape # N = n_gals, K = n_samples_per_gal | ||
batch_size = ceil(N / n_jacks) | ||
|
||
g_best_list = [] | ||
keys = random.split(rng_key, n_jacks) | ||
|
||
for ii in tqdm(range(n_jacks), desc="Jackknife #", disable=disable_bar): | ||
k_ii = keys[ii] | ||
start, end = ii * batch_size, (ii + 1) * batch_size | ||
|
||
_params_jack_pos = { | ||
k: jnp.concatenate([v[:start], v[end:]]) for k, v in post_params_pos.items() | ||
} | ||
_params_jack_neg = { | ||
k: jnp.concatenate([v[:start], v[end:]]) for k, v in post_params_neg.items() | ||
} | ||
|
||
g_pos_ii = shear_pipeline(k_ii, _params_jack_pos, init_g) | ||
g_neg_ii = shear_pipeline(k_ii, _params_jack_neg, -init_g) | ||
g_best_ii = (g_pos_ii - g_neg_ii) * 0.5 | ||
g_best_mean_ii = g_best_ii.mean(axis=0) | ||
|
||
g_best_list.append(g_best_mean_ii) | ||
|
||
g_best_means = jnp.array(g_best_list) | ||
return g_best_means |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.