Skip to content

Commit

Permalink
various test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Nov 5, 2024
1 parent 38903ee commit a562322
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions tests/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,15 @@ def test_low_noise_single_galaxy_interim_samples(seed):

draw_params = {**galaxy_params}
draw_params["f"] = 10 ** draw_params.pop("lf")
target_image = get_target_images_single(
target_image, _ = get_target_images_single(
nkey,
n_samples=1,
single_galaxy_params=draw_params,
background=background,
slen=slen,
)[0]
)
assert target_image.shape == (1, slen, slen)

true_params = get_true_params_from_galaxy_params(galaxy_params)

pipe1 = partial(
Expand All @@ -178,10 +180,13 @@ def test_low_noise_single_galaxy_interim_samples(seed):

# chain initialization
# one galaxy, test convergence, so 4 random seeds
keys = random.split(gkey, 4)
init_positions = vmap(init_fnc, (0, None))(keys, true_params)
gkey1, gkey2 = random.split(gkey, 2)
keys1 = random.split(gkey1, 4)
keys2 = random.split(gkey2, 4)

init_positions = vmap(init_fnc, (0, None))(keys1, true_params)

samples = vpipe1(keys, init_positions, target_image)
samples = vpipe1(keys2, init_positions, target_image)

# check each component
for _, v in samples.items():
Expand Down

0 comments on commit a562322

Please sign in to comment.