Skip to content

Commit

Permalink
burn in and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Aug 9, 2024
2 parents e144ce7 + 148c028 commit 312e746
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 20 deletions.
17 changes: 5 additions & 12 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
dual_averaging_adaptation,
)
from blackjax.base import AdaptationAlgorithm
from blackjax.progress_bar import progress_bar_scan
from blackjax.progress_bar import gen_scan_fn
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size

Expand Down Expand Up @@ -333,23 +333,16 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):

if progress_bar:
print("Running window adaptation")
one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step))
start_state = ((init_state, init_adaptation_state), -1)
else:
one_step_ = jax.jit(one_step)
start_state = (init_state, init_adaptation_state)

scan_fn = gen_scan_fn(num_steps, progress_bar=progress_bar)
start_state = (init_state, init_adaptation_state)
keys = jax.random.split(rng_key, num_steps)
schedule = build_schedule(num_steps)
last_state, info = jax.lax.scan(
one_step_,
last_state, info = scan_fn(
one_step,
start_state,
(jnp.arange(num_steps), keys, schedule),
)

if progress_bar:
last_state, _ = last_state

last_chain_state, last_warmup_state, *_ = last_state

step_size, inverse_mass_matrix = adapt_final(last_warmup_state)
Expand Down
14 changes: 14 additions & 0 deletions blackjax/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,17 @@ def wrapper_progress_bar(carry, x):
return wrapper_progress_bar

return _progress_bar_scan


def gen_scan_fn(num_samples, progress_bar, print_rate=None):
if progress_bar:

def scan_wrap(f, init, *args, **kwargs):
func = progress_bar_scan(num_samples, print_rate)(f)
carry = (init, -1)
(last_state, _), output = lax.scan(func, carry, *args, **kwargs)
return last_state, output

return scan_wrap
else:
return lax.scan
18 changes: 10 additions & 8 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jax.tree_util import tree_leaves

from blackjax.base import SamplingAlgorithm, VIAlgorithm
from blackjax.progress_bar import progress_bar_scan
from blackjax.progress_bar import gen_scan_fn
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey


Expand Down Expand Up @@ -200,13 +200,15 @@ def one_step(state, xs):
state, info = inference_algorithm.step(rng_key, state)
return state, transform(state, info)

if progress_bar:
one_step = progress_bar_scan(num_steps)(one_step)
xs = jnp.arange(num_steps), keys
final_state, history = lax.scan(one_step, (initial_state, -1), xs)
else:
xs = jnp.arange(num_steps), keys
final_state, history = lax.scan(one_step, initial_state, xs)
scan_fn = gen_scan_fn(num_steps, progress_bar)

# if progress_bar:
# one_step = progress_bar_scan(num_steps)(one_step)
# xs = jnp.arange(num_steps), keys
# final_state, history = lax.scan(one_step, (initial_state, -1), xs)

xs = jnp.arange(num_steps), keys
final_state, history = scan_fn(one_step, initial_state, xs)

return final_state, history

Expand Down

0 comments on commit 312e746

Please sign in to comment.