Skip to content

Commit

Permalink
draft of exp 2 with galaxies between snr (8, 100)
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Nov 19, 2024
1 parent 549095f commit 949f3a8
Showing 1 changed file with 143 additions and 0 deletions.
143 changes: 143 additions & 0 deletions experiments/exp2/run_inference_galaxy_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#!/usr/bin/env python3
"""Check chains ran on a variety of galaxies with different SNR, initialization from the prior."""

import time
from functools import partial

import jax.numpy as jnp
import typer
from jax import jit as jjit
from jax import random, vmap
from jax._src.prng import PRNGKeyArray

from bpd import DATA_DIR
from bpd.chains import run_sampling_nuts, run_warmup_nuts
from bpd.draw import draw_gaussian
from bpd.initialization import init_with_prior
from bpd.pipelines.image_samples import (
get_target_galaxy_params_simple,
get_target_images,
get_true_params_from_galaxy_params,
loglikelihood,
logprior,
logtarget,
)


def sample_prior(
rng_key: PRNGKeyArray,
*,
flux_bds: tuple = (2.5, 4.0),
hlr_bds: tuple = (0.7, 2.0),
shape_noise: float = 0.3,
g1: float = 0.02,
g2: float = 0.0,
) -> dict[str, float]:
k1, k2, k3 = random.split(rng_key, 3)

lf = random.uniform(k1, minval=flux_bds[0], maxval=flux_bds[1])
hlr = random.uniform(k2, minval=hlr_bds[0], maxval=hlr_bds[1])

other_params = get_target_galaxy_params_simple(
k3, shape_noise=shape_noise, g1=g1, g2=g2
)

return {"lf": lf, "hlr": hlr, **other_params}


def _sample_prior_init(rng_key: PRNGKeyArray):
prior_samples = sample_prior(rng_key)
truth_samples = get_true_params_from_galaxy_params(prior_samples)
return truth_samples


INIT_FNC = partial(init_with_prior, prior=_sample_prior_init)


def main(
seed: int,
n_samples: int = 100,
shape_noise: float = 0.3,
sigma_e_int: float = 0.5,
slen: int = 53,
fft_size: int = 256,
background: float = 1.0,
initial_step_size: float = 0.1,
):
rng_key = random.key(seed)
pkey, nkey, ikey, rkey = random.split(rng_key, 4)

# directory structure
dirpath = DATA_DIR / "cache_chains" / f"test_image_sampling_{seed}"
if not dirpath.exists():
dirpath.mkdir(exist_ok=True)

draw_fnc = partial(draw_gaussian, slen=slen, fft_size=fft_size)
_loglikelihood = partial(loglikelihood, draw_fnc=draw_fnc, background=background)
_logprior = partial(logprior, sigma_e=sigma_e_int)
_logtarget = partial(
logtarget, logprior_fnc=_logprior, loglikelihood_fnc=_loglikelihood
)

_run_warmup1 = partial(
run_warmup_nuts,
logtarget=_logtarget,
initial_step_size=initial_step_size,
max_num_doublings=5,
n_warmup_steps=500,
)
_run_warmup = vmap(vmap(jjit(_run_warmup1), in_axes=(0, 0, None)))

_run_sampling1 = partial(
run_sampling_nuts,
logtarget=_logtarget,
n_samples=n_samples,
max_num_doublings=5,
)
_run_sampling = vmap(vmap(jjit(_run_sampling1), in_axes=(0, 0, 0, None)))

results = {}
for n_gals in (1, 1, 5, 10, 50, 100, 250): # repeat 1 == compilation
# generate data and parameters
pkeys = random.split(pkey, n_gals)
galaxy_params = vmap(partial(sample_prior, shape_noise=shape_noise))(pkeys)
assert galaxy_params["x"].shape == (n_gals,)

draw_params = {**galaxy_params}
draw_params["f"] = 10 ** draw_params.pop("lf")
target_images = get_target_images(
nkey, draw_params, background=background, slen=slen
)
assert target_images.shape == (n_gals, slen, slen)
true_params = vmap(get_true_params_from_galaxy_params)(galaxy_params)

# initialize positions
ikeys = random.split(ikey, (n_gals, 4))
init_positions = vmap(vmap(INIT_FNC, in_axes=(0, None)))(ikeys, true_params)

gkeys = random.split(rkey, (n_gals, 4, 2))
wkeys = gkeys[..., 0]
ikeys = gkeys[..., 1]

# warmup
t1 = time.time()
init_states, tuned_params, _ = _run_warmup(wkeys, init_positions, target_images)
t2 = time.time()
t_warmup = t2 - t1
tuned_params.pop("max_num_doublings") # set above, not jittable

# inference
t1 = time.time()
samples, _ = _run_sampling(ikeys, init_states, tuned_params, target_images)
t2 = time.time()
t_sampling = t2 - t1

results[n_gals]["t_warmup"] = t_warmup
results[n_gals]["t_sampling"] = t_sampling
results[n_gals]["samples"] = samples

jnp.save(results)


if __name__ == "__main__":
typer.run(main)

0 comments on commit 949f3a8

Please sign in to comment.