Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

light refactoring #50

Merged
merged 9 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ Bayesian Pixel Domain shear estimation based on automatically differentiable cel

This repository contains functions to run HMC (Hamiltonian Monte Carlo) using [JAX-Galsim](https://github.com/GalSim-developers/JAX-GalSim) as a forward model to perform shear inference.


## Installation

```bash
Expand All @@ -13,15 +12,15 @@ pip install --upgrade pip
conda create -n bpd python=3.12
conda activate bpd

# Install JAX (cuda)
pip install -U "jax[cuda12]"
# Install JAX
pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html || pip install -U "jax[cpu]"

# Install JAX-Galsim
pip install git+https://github.com/GalSim-developers/JAX-GalSim.git

# Install package and depedencies
git clone [email protected]:LSSTDESC/BPD.git
cd BPD
python -m pip install . -e
python -m pip install .[dev]
pip install -e .
pip install -e ".[dev]"
```
12 changes: 7 additions & 5 deletions bpd/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

import blackjax
import jax
from jax import random
from jax import Array, random
from jax._src.prng import PRNGKeyArray
from jax.typing import ArrayLike


def inference_loop(rng_key, initial_state, kernel, n_samples: int):
"""Function to run a single chain with a given kernel and obtain `n_samples`."""
def inference_loop(
rng_key: PRNGKeyArray, initial_state: ArrayLike, kernel: Callable, n_samples: int
):
"""Function to run a single chain with a given kernel and obtain samples"""

def one_step(state, rng_key):
state, info = kernel(rng_key, state)
Expand All @@ -32,7 +34,7 @@ def run_warmup_nuts(
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = True,
target_acceptance_rate: float = 0.8,
):
) -> tuple[ArrayLike, dict, dict]:
_logtarget = partial(logtarget, data=data)
warmup = blackjax.window_adaptation(
blackjax.nuts,
Expand Down Expand Up @@ -82,7 +84,7 @@ def run_inference_nuts(
n_warmup_steps: int = 500,
target_acceptance_rate: float = 0.80,
is_mass_matrix_diagonal: bool = True,
):
) -> Array | dict[str, Array]:
key1, key2 = random.split(rng_key)

_logtarget = partial(logtarget, data=data)
Expand Down
3 changes: 1 addition & 2 deletions bpd/diagnostics.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import numpy as np
import pandas as pd
from chainconsumer import Chain, ChainConsumer, Truth
from jax import Array
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes
from numpyro.diagnostics import hpdi
from scipy import stats


def get_contour_plot(
samples_list: list[dict[str, Array]],
samples_list: list[dict[str, np.ndarray]],
names: list[str],
truth: dict[str, float],
figsize: tuple[float, float] = (7, 7),
Expand Down
4 changes: 2 additions & 2 deletions bpd/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax.numpy as jnp
import jax.scipy as jsp
from jax import grad, vmap
from jax import Array, grad, vmap
from jax.numpy.linalg import norm
from jax.typing import ArrayLike

Expand All @@ -13,7 +13,7 @@


def shear_loglikelihood_unreduced(
g: tuple[float, float], e_post, prior: Callable, interim_prior: Callable
g: tuple[float, float], e_post: Array, prior: Callable, interim_prior: Callable
) -> ArrayLike:
# Given by the inference procedure in Schneider et al. 2014
# assume single shear g
Expand Down
11 changes: 7 additions & 4 deletions bpd/measure.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import numpy as np
from jax.typing import ArrayLike
import jax.numpy as jnp
from jaxtyping import ArrayLike


def get_snr(im: ArrayLike, background: float) -> float:
"""Calculate the signal-to-noise ratio of an image.

Args:
im: Image array with no background.
im: 2D image array with no background.
background: Background level.

Returns:
float: The signal-to-noise ratio.
"""
assert im.ndim == 2
assert isinstance(background, float) or background.shape == ()
return np.sqrt(np.sum(im * im / (background + im)))
return jnp.sqrt(jnp.sum(im * im / (background + im)))
10 changes: 5 additions & 5 deletions bpd/pipelines/image_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,16 @@ def loglikelihood(
background: float,
free_flux: bool = True,
):
# NOTE: draw_fnc should already contain `f` and `hlr` as constant arguments.
_draw_params = {**{"g1": 0.0, "g2": 0.0}, **params} # function is more general
# NOTE: draw_fnc should already contain `f` and `hlr` as constant arguments if fixed
_draw_params = {**{"g1": 0.0, "g2": 0.0}, **params}

# Convert log-flux to flux if provided
if free_flux:
_draw_params["f"] = 10 ** _draw_params.pop("lf")
model = draw_fnc(**_draw_params)

model = draw_fnc(**_draw_params)
likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background))
likelihood = jnp.sum(likelihood_pp)
return likelihood
return jnp.sum(likelihood_pp)


def logtarget(
Expand Down
14 changes: 9 additions & 5 deletions bpd/prior.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import jax.numpy as jnp
from jax import Array, random
from jax._src.prng import PRNGKeyArray
from jax.numpy.linalg import norm
from jaxtyping import ArrayLike


def ellip_mag_prior(e, sigma: float):
def ellip_mag_prior(e: ArrayLike, sigma: float):
"""Unnormalized Prior for the magnitude of the ellipticity, domain is (0, 1)

This distribution is taken from Gary's 2013 paper on Bayesian shear inference.
Expand All @@ -15,7 +17,9 @@ def ellip_mag_prior(e, sigma: float):
return (1 - e**2) ** 2 * jnp.exp(-(e**2) / (2 * sigma**2))


def sample_mag_ellip_prior(rng_key, sigma: float, n: int = 1, n_bins: int = 1_000_000):
def sample_mag_ellip_prior(
rng_key: PRNGKeyArray, sigma: float, n: int = 1, n_bins: int = 1_000_000
):
"""Sample n points from Gary's ellipticity magnitude prior."""
# this part could be cached
e_array = jnp.linspace(0, 1, n_bins)
Expand All @@ -25,7 +29,7 @@ def sample_mag_ellip_prior(rng_key, sigma: float, n: int = 1, n_bins: int = 1_00
return random.choice(rng_key, e_array, shape=(n,), p=p_array)


def sample_ellip_prior(rng_key, sigma: float, n: int = 1):
def sample_ellip_prior(rng_key: PRNGKeyArray, sigma: float, n: int = 1):
"""Sample n ellipticities isotropic components with Gary's prior from magnitude."""
key1, key2 = random.split(rng_key, 2)
e_mag = sample_mag_ellip_prior(key1, sigma=sigma, n=n)
Expand Down Expand Up @@ -98,7 +102,7 @@ def inv_shear_transformation(e: Array, g: tuple[float, float]):

# get synthetic measured sheared ellipticities
def sample_synthetic_sheared_ellips_unclipped(
rng_key,
rng_key: PRNGKeyArray,
g: tuple[float, float],
n: int,
sigma_m: float,
Expand All @@ -114,7 +118,7 @@ def sample_synthetic_sheared_ellips_unclipped(


def sample_synthetic_sheared_ellips_clipped(
rng_key,
rng_key: PRNGKeyArray,
g: tuple[float, float],
sigma_m: float,
sigma_e: float,
Expand Down
13 changes: 10 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ description = "Bayesian Pixel Domain method for shear inference."
version = "0.0.1"
license = { file = "LICENSE" }
readme = "README.md"
dependencies = ["numpy >=1.18.0", "galsim >=2.3.0", "jax >=0.4.30", "jaxlib", "blackjax >=1.2.0"]
dependencies = [
"numpy >=1.18.0",
"galsim >=2.3.0",
"jax >=0.4.30",
"jaxlib",
"blackjax >=1.2.0",
"numpyro >=0.13.0",
]


[project.optional-dependencies]
Expand Down Expand Up @@ -58,8 +65,8 @@ exclude = [
line-length = 88
indent-width = 4

# Assume Python 3.8
target-version = "py310"
# Assume Python 3.12
target-version = "py312"

[tool.ruff.format]
# Like Black, use double quotes for strings.
Expand Down