diff --git a/README.md b/README.md index a8d847cf9..9590b4cd6 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,10 @@ state = nuts.init(initial_position) # Iterate rng_key = jax.random.key(0) +step = jax.jit(nuts.step) for step in range(100): nuts_key = jax.random.fold_in(rng_key, step) - state, _ = nuts.step(nuts_key, state) + state, _ = step(nuts_key, state) ``` See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.