diff --git a/docs/developer/principles.md b/docs/developer/principles.md new file mode 100644 index 000000000..372c2b7a9 --- /dev/null +++ b/docs/developer/principles.md @@ -0,0 +1,35 @@ +# Principles of design + +The design of the library is to achieve, in no particular order: +- A library of algorithms used to approximate integrals on probability spaces (***not*** a PPL). +- Code that is easy to read and understand. +- Flexibility for the user, i.e. construction of meta algorithms (abstractions?). +- Well documented code, that describes in detail the inner mechanism of the algorithm. +- Benchmarks for quality of approximations and algorithm efficiency. +- Leverage JAX's unique strengths (functional programming, composable function-transformation approach). + +Quoting from [Flax's Linen's design principles](https://flax.readthedocs.io/en/latest/philosophy.html): + +> Arguably the entire point of a \[sampling\] library is to offer an implicit variable management API to save the user from having to manually thread thousands of variables through a complex tree of functions. However, JAX operates on pure functions. + +Of course, in our case, the number of variables will seldom reach the thousands; nonetheless, abstracting the variable management *while* allowing the user to create by composing base algorithms should be the main objective of the library. In fact, the end user should have the ability to create algorithms with no ergodicity guarantees or that fail to leave the target distribution invariant. The probabilistic properties of the algorithm, like the Markov chain it creates or the features of the density it uses to approximate, should be the focus of the end user and its research or testing objectives. Blackjax provides only the tools to create, compose and develop. For example, HMC simulates a discretized version of Hamiltonian dynamics to then ensure detailed balance with a Metropolis-Hastings step, and it should be straightforward to do unadjusted HMC by simply removing the Metropolis-Hastings step from the algorithm, or to add adaptation by including and extra adaptation step. + +```python +#Hamiltonian Monte Carlo +init_hmc, step_hmc = blackjax.sequential([blackjax.velocity_verlet, blackjax.metropolis_hastings])(logdensity_fn) +variables = init_hmc(initial_position, step_size) +new_variables = step_hmc(variables) + +#Unadjusted Hamiltonian Monte Carlo +init_uhmc, step_uhmc = blackjax.velocity_verlet(logdensity_fn) +variables = init_uhmc(initial_position, step_size) +new_variables = step_uhmc(variables) + +#Hamiltonain Monte Carlo with Window Adaptation +init_adapt_hmc, step_adapt_hmc = blackjax.sequential([blackjax.velocity_verlet, blackjax.metropolis_hastings, blackjax.window_adaptation])(logdensity_fn) +#warm-up +variables = init_adapt_hmc(initial_position, step_size) +new_variables = step_adapt_hmc(variables) +#sampling +sample_variables = step_uhmc(new_variables) +``` \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 6b0dc6cb2..f71dd577a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -130,3 +130,12 @@ vi adaptation diagnostics ``` + +```{toctree} +--- +maxdepth: 1 +caption: DEVELOPER DOCUMENTATION +hidden: +--- +Principles +```