Commit

revise text and skeletons to new API
albcab committed Feb 22, 2024
1 parent 6f94b94 commit 8b2edea
Showing 4 changed files with 40 additions and 38 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -131,7 +131,7 @@ passing parameters.

### New algorithms

We hope to make implementing and testing new algorithms easy with BlackJAX. Many basic methods are already implemented in the library, and you can use them to test new algorithms. Follow the [guidelines](https://blackjax-devs.github.io/blackjax/developer/guidelines.html) to implement your own method and test new ideas on existing methods without writing everything from scratch!
We want to make implementing and testing new algorithms easy with BlackJAX. You can test new algorithms by reusing the basic components of the many known methods already implemented in the library. Follow the [guidelines](https://blackjax-devs.github.io/blackjax/developer/guidelines.html) to implement your method and test new ideas on existing methods without writing everything from scratch.

## Contributions

6 changes: 3 additions & 3 deletions docs/developer/approximate_inf_algorithm.py
@@ -19,7 +19,7 @@
# import basic components that are already implemented
# or that you have implemented with a general structure
from blackjax.base import VIAlgorithm
from blackjax.types import PRNGKey, PyTree
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = [
"ApproxInfState",
@@ -49,7 +49,7 @@ class ApproxInfInfo(NamedTuple):
...


def init(position: PyTree, logdensity_fn: Callable, *args, **kwargs):
def init(position: ArrayLikeTree, logdensity_fn: Callable, *args, **kwargs):
    # build an initial state
state = ApproxInfState(...)
return state
@@ -116,7 +116,7 @@ def __new__(  # type: ignore[misc]
*args,
**kwargs,
) -> VIAlgorithm:
def init_fn(position: PyTree):
def init_fn(position: ArrayLikeTree):
return cls.init(position, optimizer, ...)

def step_fn(rng_key: PRNGKey, state):
43 changes: 22 additions & 21 deletions docs/developer/guidelines.md
@@ -1,43 +1,44 @@
# Developer Guidelines

In the broadest sense, an algorithm that belongs in the BlackJAX library should provide the tools to approximate integrals on a probability space. An introduction to probability theory is outside the scope of this document, but the Monte Carlo method is ever-present and important to understand. In simple terms, we want to approximate an integral with a sum. To do this, we generate samples whose [relative likelihood](https://en.wikipedia.org/wiki/Relative_likelihood) is given by a target probability density function (known up to a normalization constant). The idea is to sample more from areas with higher likelihood, but also from areas with low likelihood, just at a lower rate. You can also approximate the target density directly, using a density that is tractable and easy to sample from, then do inference with the approximation instead of the target, potentially using [importance sampling](https://en.wikipedia.org/wiki/Importance_sampling) to correct the approximation error.

In the following section, we’ll explain BlackJAX’s design of different algorithms for Monte Carlo integration. Keep in mind some basic principles:

- Leverage JAX's unique strengths: functional programming and composable function-transformation approach.
- Write small and general functions, compose them to create complex methods, reuse the same building blocks for similar algorithms.
- Write small and general functions, compose them to create complex methods, and reuse the same building blocks for similar algorithms.
- Consider compatibility with the broader JAX ecosystem (Flax, Optax, GPJax).
- Write code that is easy to read and understand.
- Write code that is well documented, describe in detail the inner mechanism of the algorithm and its use.
- Write well-documented code describing in detail the inner mechanism of the algorithm and its use.

## Core implementation
There are three types of sampling algorithms BlackJAX currently supports: Markov Chain Monte Carlo (MCMC), Sequential Monte Carlo (SMC), and Stochastic Gradient MCMC (SGMCMC); and one type of approximate inference algorithm: Variational Inference (VI). Additionally, BlackJAX supports adaptation algorithms that efficiently tune the hyperparameters of sampling algorithms, usually aimed at reducing autocorrelation between sequential samples.

Basic components are functions which do specific tasks but are generally applicable, used to build all inference algorithms. When implementing a new inference algorithm you should first break it down to its basic components then find and use all that are already implemented *before* writing your own. A recurrent example is the [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In BlackJAX there are two basic components that do a specific (but simpler) and a general version of this accept/reject step:

- Metropolis step: if the proposal transition kernel is symmetric, i.e. if the probability of going from the initial to the proposed position is always equal to the probability of going from the proposed to the initial position, the acceptance probability is calculated by creating a proposal using `mcmc.proposal.proposal_generator`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`.
- Metropolis-Hastings step: for the more general case of an asymmetric proposal transition kernel, the acceptance probability is calculated by creating a proposal using `mcmc.proposal.asymmetric_proposal_generator`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`.

When implementing an algorithm you could choose to replace the classic, reversible Metropolis-Hastings step with Neal's [non-reversible slice sampling](https://arxiv.org/abs/2001.11950) step by simply replacing `mcmc.proposal.static_binomial_sampling` with `mcmc.proposal.nonreversible_slice_sampling` on either of the previous implementations. Just make sure to carry over to the next iteration an updated slice, instead of passing a pseudo-random number generating key, for the slice sampling step!

The previous example illustrates the power of basic components, useful not only to avoid rewriting the same methods for each new algorithm but also useful to personalize and test new algorithms which replace some steps of common efficient algorithms. Like how `blackjax.mcmc.ghmc` is `blackjax.mcmc.hmc` with a persistent momentum and a non-reversible slice sampling step instead of the Metropolis-Hastings step.

Because JAX operates on pure functions, inference algorithms always return a `typing.NamedTuple` containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of BlackJAX, so it must be done in a way that abstracts the uninteresting bookkeeping from the end user but allows her to access important variables at each step. The algorithms should also return a `typing.NamedTuple` with important information about each iteration.
BlackJAX supports sampling algorithms such as Markov Chain Monte Carlo (MCMC), Sequential Monte Carlo (SMC), and Stochastic Gradient MCMC (SGMCMC), as well as approximate inference algorithms such as Variational Inference (VI). In all cases, BlackJAX takes a Markovian approach, whereby the current state contains all the information needed to obtain the next iteration of an algorithm. This naturally results in a functionally pure structure, where no side effects are allowed, simplifying parallelisation. Additionally, BlackJAX supports adaptation algorithms that efficiently tune the hyperparameters of sampling algorithms, usually aimed at reducing autocorrelation between sequential samples.

The user-facing interface of a **sampling algorithm** should work like this:
The user-facing interface of a **sampling algorithm** is made up of an initializer and an iterator:
```python
# Generic sampling algorithm:
sampling_algorithm = blackjax.sampling_algorithm(logdensity_fn, *args, **kwargs)
state = sampling_algorithm.init(initial_position)
new_state, info = sampling_algorithm.step(rng_key, state)
```
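As a toy illustration of this interface, here is a random-walk Metropolis sampler structured as an initializer and an iterator. This is a plain-Python sketch, not BlackJAX code: the `RWState`, `rw_init`, and `rw_step` names are invented for the example.

```python
import math
import random
from typing import Callable, NamedTuple


class RWState(NamedTuple):
    # everything the kernel needs to produce the next sample
    position: float
    logdensity: float


def rw_init(position: float, logdensity_fn: Callable) -> RWState:
    return RWState(position, logdensity_fn(position))


def rw_step(rng: random.Random, state: RWState, logdensity_fn: Callable, step_size: float):
    # symmetric Gaussian proposal around the current position
    proposed = state.position + step_size * rng.gauss(0.0, 1.0)
    proposed_logdensity = logdensity_fn(proposed)
    # Metropolis accept/reject: the proposal is symmetric, so only the density ratio matters
    do_accept = math.log(rng.random()) < proposed_logdensity - state.logdensity
    new_state = RWState(proposed, proposed_logdensity) if do_accept else state
    return new_state, {"is_accepted": do_accept}


# usage mirrors the generic init/step interface above
logdensity_fn = lambda x: -0.5 * x**2  # standard normal, up to a constant
rng = random.Random(0)
state = rw_init(1.0, logdensity_fn)
for _ in range(100):
    state, info = rw_step(rng, state, logdensity_fn, step_size=0.5)
```

Note how the returned state carries the cached log-density forward, so the target function is evaluated only once per proposal.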
Achieve this by building from the basic skeleton of a sampling algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/sampling_algorithm.py). Only the `sampling_algorithm` class and the `init` and `build_kernel` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary.
Build from the basic skeleton of a sampling algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/sampling_algorithm.py). Only the `sampling_algorithm` class and the `init` and `build_kernel` functions need to be in the final version of your algorithm; the rest might be useful but are not necessary.

The user-facing interface of an **approximate inference algorithm** should work like this:
The user-facing interface of an **approximate inference algorithm** is made up of an initializer, iterator, and sampler:
```python
# Generic approximate inference algorithm:
approx_inf_algorithm = blackjax.approx_inf_algorithm(logdensity_fn, optimizer, *args, **kwargs)
state = approx_inf_algorithm.init(initial_position)
new_state, info = approx_inf_algorithm.step(rng_key, state)
#user is able to build the approximate distribution using the state, or generate samples:
# user is able to build the approximate distribution using the state, or generate samples:
position_samples = approx_inf_algorithm.sample(rng_key, state, num_samples)
```
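As a toy illustration of the initializer/iterator/sampler triple, the sketch below fits the mean of a fixed-variance Gaussian approximation. This is a plain-Python stand-in, not BlackJAX code: the names are invented, and the score-following update is a crude substitute for a proper ELBO gradient.

```python
import random
from typing import Callable, NamedTuple


class ApproxState(NamedTuple):
    # mean of a fixed-variance Gaussian approximation to the target
    mu: float


def approx_init(position: float) -> ApproxState:
    return ApproxState(position)


def approx_step(state: ApproxState, grad_logdensity_fn: Callable, learning_rate: float = 0.1):
    # move the approximation's mean along the target's score
    # (a crude stand-in for a real ELBO gradient update)
    new_mu = state.mu + learning_rate * grad_logdensity_fn(state.mu)
    return ApproxState(new_mu), {"mu": new_mu}


def approx_sample(rng: random.Random, state: ApproxState, num_samples: int):
    # draw from the approximating distribution, not from the target itself
    return [state.mu + rng.gauss(0.0, 1.0) for _ in range(num_samples)]


grad_logdensity_fn = lambda x: -x  # score of a standard normal target
state = approx_init(5.0)
for _ in range(50):
    state, info = approx_step(state, grad_logdensity_fn)
position_samples = approx_sample(random.Random(0), state, num_samples=3)
```

The key design point is the same as for samplers: `step` consumes and returns an immutable state, and `sample` is a separate pure function of that state.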
Achieve this by building from the basic skeleton of an approximate inference algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/approximate_inf_algorithm.py). Only the `approx_inf_algorithm` class and the `init`, `step` and `sample` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary.
Build from the basic skeleton of an approximate inference algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/approximate_inf_algorithm.py). Only the `approx_inf_algorithm` class and the `init`, `step` and `sample` functions need to be in the final version of your algorithm; the rest might be useful but are not necessary.

## Basic components
All inference algorithms are composed of basic components which provide the lowest level of algorithm abstraction and are available to the user. When implementing a new inference algorithm, you should first break it down into its basic components, then find and use those already implemented *before* writing your own. A recurrent example is the [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In BlackJAX, two basic components do a specific (but simpler) and a general version of this accept/reject step:

- Metropolis step: if the proposal transition kernel is symmetric, i.e. if the probability of going from the initial to the proposed position is always equal to the probability of going from the proposed to the initial position, the acceptance probability is calculated using `mcmc.proposal.safe_energy_diff`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. For example, see `mcmc.hmc.hmc_proposal`.
- Metropolis-Hastings step: for the more general case of an asymmetric proposal transition kernel, the acceptance probability is calculated by creating a proposal using `mcmc.proposal.compute_asymmetric_acceptance_ratio`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. For example, see `mcmc.mala.build_kernel`.
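Both variants end in the same binomial accept/reject draw. A minimal standalone sketch of that step (hypothetical function, not the exact BlackJAX signature):

```python
import math
import random


def binomial_sampling(rng: random.Random, log_acceptance_ratio: float, current, proposal):
    # accept the proposal with probability min(1, exp(log_acceptance_ratio))
    p_accept = math.exp(min(0.0, log_acceptance_ratio))
    do_accept = rng.random() < p_accept
    return (proposal if do_accept else current), do_accept, p_accept


rng = random.Random(0)
# a positive log-ratio means the proposal is at least as likely: always accepted
new_position, accepted, p_accept = binomial_sampling(rng, 0.5, "current", "proposal")
```

Because the draw only needs a log acceptance ratio, the same component serves both the symmetric (Metropolis) and asymmetric (Metropolis-Hastings) cases.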

When implementing an algorithm you could choose to replace the classic, reversible Metropolis-Hastings step with Neal's [non-reversible slice sampling](https://arxiv.org/abs/2001.11950) step by simply replacing `mcmc.proposal.static_binomial_sampling` with `mcmc.proposal.nonreversible_slice_sampling` on either of the previous implementations. Make sure to carry over to the next iteration an updated slice for the slice sampling step, instead of passing a pseudo-random number generating key!

The previous example illustrates the power of basic components: they are useful not only for avoiding rewriting the same methods for each new algorithm, but also for personalizing and testing new algorithms that replace some steps of standard efficient algorithms. For example, `blackjax.mcmc.ghmc` is `blackjax.mcmc.hmc` with persistent momentum and a non-reversible slice sampling step in place of the Metropolis-Hastings step.

Because JAX operates on pure functions, inference algorithms always return a `typing.NamedTuple` containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of BlackJAX, so you must do it in a way that hides the uninteresting bookkeeping from the end user while still allowing them to access important variables at each step. The algorithms should also return a `typing.NamedTuple` with important information about each iteration.
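The state/info pair can be sketched as follows (a minimal illustration with invented field names, not any specific BlackJAX algorithm's tuples):

```python
from typing import NamedTuple


class AlgoState(NamedTuple):
    # carried from one iteration to the next; all the kernel needs
    position: float
    logdensity: float


class AlgoInfo(NamedTuple):
    # per-iteration diagnostics returned to the user, never fed back into the chain
    acceptance_rate: float
    is_accepted: bool


state = AlgoState(position=0.0, logdensity=-1.5)
info = AlgoInfo(acceptance_rate=0.85, is_accepted=True)
```

Keeping diagnostics in a separate tuple means the step function's signature stays fixed even as algorithms expose more or less information per iteration.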
27 changes: 14 additions & 13 deletions docs/developer/sampling_algorithm.py
@@ -19,8 +19,8 @@
# or that you have implemented with a general structure
# for example, if you do a Metropolis-Hastings accept/reject step:
import blackjax.mcmc.proposal as proposal
from blackjax.base import MCMCSamplingAlgorithm
from blackjax.types import PRNGKey, PyTree
from blackjax.base import SamplingAlgorithm
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = [
"SamplingAlgoState",
@@ -49,7 +49,7 @@ class SamplingAlgoInfo(NamedTuple):
...


def init(position: PyTree, logdensity_fn: Callable, *args, **kwargs):
def init(position: ArrayLikeTree, logdensity_fn: Callable, *args, **kwargs):
    # build an initial state
state = SamplingAlgoState(...)
return state
@@ -117,10 +117,10 @@ def __new__(  # type: ignore[misc]
logdensity_fn: Callable,
*args,
**kwargs,
) -> MCMCSamplingAlgorithm:
) -> SamplingAlgorithm:
kernel = cls.build_kernel(...)

def init_fn(position: PyTree):
def init_fn(position: ArrayLikeTree):
return cls.init(position, logdensity_fn, ...)

def step_fn(rng_key: PRNGKey, state):
@@ -131,7 +131,7 @@ def step_fn(rng_key: PRNGKey, state):
...,
)

return MCMCSamplingAlgorithm(init_fn, step_fn)
return SamplingAlgorithm(init_fn, step_fn)


# and other functions that help make `init` and/or `build_kernel` easier to read and understand
@@ -148,20 +148,21 @@ def sampling_algorithm_proposal(*args, **kwags) -> Callable:
-------
Describe what is returned.
"""
# as an example, a Metropolis-Hastings step would look like this:
init_proposal, generate_proposal = proposal.proposal_generator(...)
sample_proposal = proposal.static_binomial_sampling(...)
    # as an example, a Metropolis-Hastings step with a symmetric transition would look like this:
acceptance_ratio = proposal.safe_energy_diff
sample_proposal = proposal.static_binomial_sampling

def generate(rng_key, state):
# propose a new sample
proposal_state = ...

# accept or reject the proposed sample
proposal = init_proposal(state)
new_proposal, is_diverging = generate_proposal(proposal.energy, proposal_state)
sampled_proposal, *info = sample_proposal(rng_key, proposal, new_proposal)
initial_energy = ...
proposal_energy = ...
new_proposal, is_diverging = acceptance_ratio(initial_energy, proposal_energy)
sampled_state, info = sample_proposal(rng_key, proposal, new_proposal)

# build a new state and collect useful information
# maybe add to the returned state and collect more useful information
sampled_state, info = ...

return sampled_state, info