-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate progress bar from fastprogress to tqdm #655
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -14,70 +14,83 @@ | |||||
"""Progress bar decorators for use with step functions. | ||||||
Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`. | ||||||
""" | ||||||
from fastprogress.fastprogress import progress_bar | ||||||
import jax | ||||||
from jax import lax | ||||||
from jax.experimental import io_callback | ||||||
from tqdm.auto import tqdm as tqdm_auto | ||||||
|
||||||
|
||||||
def progress_bar_scan(num_samples, print_rate=None): | ||||||
"Progress bar for a JAX scan" | ||||||
progress_bars = {} | ||||||
def progress_bar_scan(num_samples, num_chains=1, print_rate=None): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC in the usage we need to specify the Line 198 in 7cf4f9d
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not fully committed to this API, but I was thinking something where along with passing an array of iteration numbers, you also pass in the chain you are currently in. I think this is better than the numpyro design where you are using regexes on device objects to guess what chain to put the computation on. def inference_loop(rng_key, kernel, initial_state, chain, num_samples, num_chains):
def _one_step(state, xs):
_, _, rng_key = xs
state, _ = kernel(rng_key, state)
return state, state
one_step = jax.jit(progress_bar_factory(num_samples, num_chains)(_one_step))
keys = jax.random.split(rng_key, num_samples)
_, states = jax.lax.scan(
one_step,
initial_state,
(np.arange(num_samples), chain * np.ones(num_samples), keys),
)
return states
inference_loop_multiple_chains = jax.pmap(
inference_loop,
in_axes=(0, None, 0, 0, None, None),
static_broadcasted_argnums=(1, 4, 5),
devices=jax.devices(),
) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For downstream applications that don't use multiple chains, I have included logic to maintain backward compatibility. Though I'm not sure how actual code is implementing progress bars for multiple chains today. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, could you share a small jupyter notebook how it looks like? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a self-contained example https://gist.github.com/zaxtax/5fd7c881c6ac83a7ca2798d0a7e230b7 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, that's very helpful. Let me think about it a bit. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perfectly happy to rework the API. This is was an attempt to make something simple and backwards compatible. |
||||||
"""Factory that builds a progress bar decorator along | ||||||
with the `set_tqdm_description` and `close_tqdm` functions | ||||||
""" | ||||||
|
||||||
if print_rate is None: | ||||||
if num_samples > 20: | ||||||
print_rate = int(num_samples / 20) | ||||||
else: | ||||||
print_rate = 1 # if you run the sampler for less than 20 iterations | ||||||
|
||||||
def _define_bar(arg): | ||||||
del arg | ||||||
progress_bars[0] = progress_bar(range(num_samples)) | ||||||
progress_bars[0].update(0) | ||||||
remainder = num_samples % print_rate | ||||||
|
||||||
def _update_bar(arg): | ||||||
progress_bars[0].update_bar(arg + 1) | ||||||
tqdm_bars = {} | ||||||
for chain in range(num_chains): | ||||||
tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain) | ||||||
tqdm_bars[chain].set_description("Compiling.. ", refresh=True) | ||||||
|
||||||
def _close_bar(arg): | ||||||
del arg | ||||||
progress_bars[0].on_iter_end() | ||||||
def _update_tqdm(arg, chain): | ||||||
chain = int(chain) | ||||||
tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False) | ||||||
tqdm_bars[chain].update(arg) | ||||||
|
||||||
def _close_tqdm(arg, chain): | ||||||
chain = int(chain) | ||||||
tqdm_bars[chain].update(arg) | ||||||
tqdm_bars[chain].close() | ||||||
|
||||||
def _update_progress_bar(iter_num, chain): | ||||||
"""Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate | ||||||
Usage: carry = progress_bar((iter_num, print_rate), carry) | ||||||
""" | ||||||
|
||||||
def _update_progress_bar(iter_num): | ||||||
"Updates progress bar of a JAX scan or loop" | ||||||
_ = lax.cond( | ||||||
iter_num == 0, | ||||||
lambda _: io_callback(_define_bar, None, iter_num), | ||||||
lambda _: jax.debug.callback(_update_tqdm, iter_num, chain), | ||||||
lambda _: None, | ||||||
operand=None, | ||||||
) | ||||||
|
||||||
_ = lax.cond( | ||||||
# update every multiple of `print_rate` except at the end | ||||||
(iter_num % print_rate == 0) | (iter_num == (num_samples - 1)), | ||||||
lambda _: io_callback(_update_bar, None, iter_num), | ||||||
(iter_num % print_rate) == 0, | ||||||
lambda _: jax.debug.callback(_update_tqdm, print_rate, chain), | ||||||
lambda _: None, | ||||||
operand=None, | ||||||
) | ||||||
|
||||||
_ = lax.cond( | ||||||
iter_num == num_samples - 1, | ||||||
lambda _: io_callback(_close_bar, None, None), | ||||||
lambda _: jax.debug.callback(_close_tqdm, remainder, chain), | ||||||
lambda _: None, | ||||||
operand=None, | ||||||
) | ||||||
|
||||||
def _progress_bar_scan(func): | ||||||
"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`. | ||||||
Note that `body_fun` must either be looping over `np.arange(num_samples)`, | ||||||
or be looping over a tuple who's first element is `np.arange(num_samples)` | ||||||
looping over a tuple whose elements are `np.arange(num_samples), and a | ||||||
chain id defined as `chain * np.ones(num_samples)`, or be looping over a | ||||||
tuple who's first element and second elements include iter_num and chain. | ||||||
This means that `iter_num` is the current iteration number | ||||||
""" | ||||||
|
||||||
def wrapper_progress_bar(carry, x): | ||||||
if type(x) is tuple: | ||||||
iter_num, *_ = x | ||||||
if num_chains > 1: | ||||||
iter_num, chain, *_ = x | ||||||
else: | ||||||
iter_num, *_ = x | ||||||
chain = 0 | ||||||
else: | ||||||
iter_num = x | ||||||
_update_progress_bar(iter_num) | ||||||
chain = 0 | ||||||
_update_progress_bar(iter_num, chain) | ||||||
return func(carry, x) | ||||||
|
||||||
return wrapper_progress_bar | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from jax.debug import callback
?