Skip to content

Commit

Permalink
return info in separate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Nov 18, 2024
1 parent c4c517a commit e9eef09
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions bpd/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def run_warmup_nuts(
target_acceptance_rate=target_acceptance_rate,
)

(init_states, tuned_params), _ = warmup.run(rng_key, init_positions, n_warmup_steps)
return init_states, tuned_params
(init_states, tuned_params), adapt_info = warmup.run(
rng_key, init_positions, n_warmup_steps
)
return init_states, tuned_params, adapt_info


def run_sampling_nuts(
Expand All @@ -62,8 +64,10 @@ def run_sampling_nuts(
kernel = blackjax.nuts(
_logtarget, **tuned_params, max_num_doublings=max_num_doublings
).step
states, _ = inference_loop(rng_key, init_states, kernel=kernel, n_samples=n_samples)
return states.position
states, info = inference_loop(
rng_key, init_states, kernel=kernel, n_samples=n_samples
)
return states.position, info


def run_inference_nuts(
Expand Down

0 comments on commit e9eef09

Please sign in to comment.