def run_jackknife_shear_pipeline(
    rng_key,
    init_g: Array,
    post_params_pos: dict,
    post_params_neg: dict,
    shear_pipeline: Callable,
    n_jacks: int = 10,
    disable_bar: bool = True,
) -> Array:
    """Compute delete-one-batch jackknife replicates of the mean shear estimate.

    The galaxies (leading axis of every array in the parameter dicts) are
    partitioned into `n_jacks` contiguous, non-empty batches whose sizes differ
    by at most one. For each batch, the shear pipeline is re-run on all galaxies
    *except* that batch, once for the positive-shear parameters and once for the
    negative-shear parameters (with the same random key), and the two posteriors
    are combined as `(g_pos - g_neg) / 2` before averaging over samples.

    Args:
        rng_key: Random jax key. It is split into `n_jacks` sub-keys, one per
            jackknife replicate.
        init_g: Initial value for shear `g`. The negative-shear run is
            initialized at `-init_g`.
        post_params_pos: Interim posterior galaxy parameters estimated using
            positive shear. Must contain the key `"e1"`; every value is sliced
            along its leading (galaxy) axis.
        post_params_neg: Interim posterior galaxy parameters estimated using
            negative shear, and otherwise same conditions and random seed as
            `post_params_pos`.
        shear_pipeline: Function called as `shear_pipeline(key, params, init_g)`
            that outputs shear posterior samples (samples along axis 0), with
            all other keyword arguments pre-specified.
        n_jacks: Number of jackknife batches. Must satisfy
            `2 <= n_jacks <= n_gals`.
        disable_bar: If True (default), no progress bar is shown.

    Returns:
        Array whose leading axis has length `n_jacks`; entry `ii` is the
        sample-mean of the combined shear posterior obtained with batch `ii`
        left out. The jackknife error on the mean can then be computed as
        `sqrt(replicates.var(axis=0) * (n_jacks - 1))`.

    Raises:
        ValueError: If `n_jacks` is smaller than 2 or larger than the number of
            galaxies (either would produce an empty batch or an empty sample).
    """
    n_gals = post_params_pos["e1"].shape[0]
    if not 2 <= n_jacks <= n_gals:
        raise ValueError(
            f"n_jacks must be between 2 and the number of galaxies ({n_gals}), "
            f"got {n_jacks}."
        )

    # Integer batch edges: identical to fixed-size batches when `n_gals` is
    # divisible by `n_jacks`, and otherwise guarantees that no batch is empty
    # (a ceil-based batch size can push trailing batches past the last galaxy).
    edges = [(ii * n_gals) // n_jacks for ii in range(n_jacks + 1)]
    keys = random.split(rng_key, n_jacks)

    batch_ids = range(n_jacks)
    if not disable_bar:
        batch_ids = tqdm(batch_ids, desc="Jackknife #")

    g_best_list = []
    for ii in batch_ids:
        start, end = edges[ii], edges[ii + 1]

        # leave batch `ii` out along the galaxy axis of every parameter
        params_jack_pos = {
            k: jnp.concatenate([v[:start], v[end:]]) for k, v in post_params_pos.items()
        }
        params_jack_neg = {
            k: jnp.concatenate([v[:start], v[end:]]) for k, v in post_params_neg.items()
        }

        # same key for both runs so that their noise realizations match
        g_pos_ii = shear_pipeline(keys[ii], params_jack_pos, init_g)
        g_neg_ii = shear_pipeline(keys[ii], params_jack_neg, -init_g)
        g_best_ii = (g_pos_ii - g_neg_ii) * 0.5

        g_best_list.append(g_best_ii.mean(axis=0))

    return jnp.array(g_best_list)
"cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", + "os.environ[\"JAX_ENABLE_X64\"] = \"True\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from math import ceil\n", + "from tqdm import tqdm\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import jax.numpy as jnp\n", + "from jax import random\n", + "\n", + "\n", + "from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples\n", + "from bpd.pipelines.shear_inference import pipeline_shear_inference_ellipticities" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "seed = 42\n", + "key = random.key(seed)\n", + "\n", + "g1 = 0.02\n", + "g2 = 0.0\n", + "true_g = jnp.array([g1, g2])\n", + "\n", + "sigma_e = 1e-3\n", + "sigma_e_int = 4e-2\n", + "sigma_m = 1e-5\n", + "n_gals = 1000\n", + "n_samples_per_gal = 50" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "k1, k2 = random.split(key)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# positive shear\n", + "e_post_pos, _, _ = pipeline_toy_ellips_samples(k1, g1=g1, g2=g2, sigma_e=sigma_e, sigma_e_int=sigma_e_int, sigma_m=sigma_m, n_gals=n_gals, n_samples_per_gal=n_samples_per_gal)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# negative shear (same key!)\n", + "e_post_neg, _, _ = pipeline_toy_ellips_samples(k1, g1=-g1, g2=-g2, sigma_e=sigma_e, sigma_e_int=sigma_e_int, sigma_m=sigma_m, n_gals=n_gals, n_samples_per_gal=n_samples_per_gal)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "tags": [] + }, + 
"outputs": [ + { + "data": { + "text/plain": [ + "(1000, 50, 2)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "e_post_pos.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "g_pos_samples = pipeline_shear_inference_ellipticities(k2, e_post_pos, true_g, sigma_e=sigma_e, sigma_e_int=sigma_e_int, n_samples=1000, initial_step_size=1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "g_neg_samples = pipeline_shear_inference_ellipticities(k2, e_post_neg, -true_g, sigma_e=sigma_e, sigma_e_int=sigma_e_int, n_samples=1000, initial_step_size=1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1000, 2)" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g_pos_samples.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(3.37190099e-05, dtype=float64)" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g_pos_samples[:, 0].std()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 6., 4., 9., 11., 17., 34., 42., 58., 73., 70., 72., 91., 80.,\n", + " 82., 73., 74., 61., 48., 37., 23., 16., 9., 5., 4., 1.]),\n", + " array([-0.0201118 , -0.02010414, -0.02009648, -0.02008883, -0.02008117,\n", + " -0.02007351, -0.02006586, -0.0200582 , -0.02005054, -0.02004288,\n", + " -0.02003523, -0.02002757, -0.02001991, -0.02001226, -0.0200046 ,\n", + " -0.01999694, -0.01998928, -0.01998163, -0.01997397, -0.01996631,\n", + " -0.01995866, -0.019951 , -0.01994334, -0.01993568, -0.01992803,\n", + " -0.01992037]),\n", + " 
)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# plt.hist(g_pos_samples[:, 0], bins=25)\n", + "plt.figure(figsize=(7,7))\n", + "plt.hist(g_neg_samples[:, 0], bins=25)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "g_best_samples = (g_pos_samples - g_neg_samples) * 0.5" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 1., 1., 2., 3., 2., 1., 6., 3., 8., 7., 24.,\n", + " 33., 58., 129., 210., 292., 131., 45., 17., 12., 4., 4.,\n", + " 3., 2., 2.]),\n", + " array([0.01999981, 0.01999981, 0.01999981, 0.01999981, 0.01999981,\n", + " 0.01999981, 0.01999981, 0.01999981, 0.01999981, 0.01999981,\n", + " 0.01999981, 0.01999981, 0.01999981, 0.01999981, 0.01999981,\n", + " 0.01999981, 0.01999981, 0.01999981, 0.01999981, 0.01999981,\n", + " 0.01999981, 0.01999981, 0.01999981, 0.01999981, 0.01999981,\n", + " 0.01999981]),\n", + " )" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(g_best_samples[:, 0], bins=25)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array(0.01999981, dtype=float64), Array(0.01998104, dtype=float64))" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g_best_samples[:, 0].mean(), g_pos_samples[:, 0].mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array(3.50319931e-10, dtype=float64), Array(3.3719104e-05, dtype=float64))" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g_best_samples[:, 0].std(), g_pos_samples[:, 0].std()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Jackknife for error on the mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# let's start by splitting into 10 batches\n", + "\n", + "# we only need to repeat the second step of the inference" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(1000, 50, 2)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "e_post_pos.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n", + "1\n", + "2\n", + "3\n", + "4\n", + "5\n", + "6\n", + "7\n", + "8\n", + "9\n" + ] + } + ], + "source": [ + "from math import ceil\n", + "n_batches = 10\n", + "batch_size = ceil(n_gals / n_batches)\n", + "\n", + "# g_pos_list = [] \n", + "# g_neg_list = []\n", + "g_best_list = [] \n", + "\n", + "keys = 
random.split(k2, n_batches)\n", + "\n", + "for ii in range(n_batches): \n", + " print(ii)\n", + "\n", + " k_ii = keys[ii]\n", + " start, end = ii * batch_size, (ii+1) * batch_size\n", + " e_pos1 = jnp.concatenate([e_post_pos[:start], e_post_pos[end:]], axis=0)\n", + " e_neg1 = jnp.concatenate([e_post_neg[:start], e_post_neg[end:]], axis=0)\n", + "\n", + " g_pos1 = pipeline_shear_inference_ellipticities(k_ii, e_post_pos, true_g, sigma_e=sigma_e, sigma_e_int=sigma_e_int, n_samples=1000, initial_step_size=1e-3)\n", + "\n", + " g_neg1 = pipeline_shear_inference_ellipticities(k_ii, e_post_neg, -true_g, sigma_e=sigma_e, sigma_e_int=sigma_e_int, n_samples=1000, initial_step_size=1e-3)\n", + "\n", + " g_best1 = (g_pos1 - g_neg1) * 0.5\n", + " g_best_mean1 = g_best1.mean(axis=0)\n", + "\n", + " g_best_list.append(g_best_mean1)\n", + "\n", + " # g_pos_list.append(g_pos1)\n", + " # g_neg_list.append(g_neg1)\n", + "\n", + "# g_pos_jack = jnp.stack(g_pos_list, axis=0)\n", + "# g_neg_jack = jnp.stack(g_neg_list, axis=0)\n", + "g_best_means = jnp.array(g_best_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(10, 2)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g_best_means.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0.01999975, 0.01999981, 0.01999981, 0.01999938, 0.01999988,\n", + " 0.01999977, 0.01999979, 0.01999971, 0.01999985, 0.01999981], dtype=float64)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g_best_means[:, 0]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(0.01999976, dtype=float64)" + ] + }, + "execution_count": 
17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g_best_means[:, 0].mean() #g1" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(3.96908294e-07, dtype=float64)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# error on the mean (jackknife expression)\n", + "jnp.sqrt(g_best_means[:, 0].var() * (n_batches-1))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(-2.44035424e-07, dtype=float64)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g_best_means[:, 0].mean() - 0.02" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "101" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g_" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(900, 50, 2)" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.concatenate([e_post_pos[:100], e_post_neg[200:]], axis=0).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bpd_gpu3", + "language": "python", + "name": "bpd_gpu3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}