Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add guidelines for new algorithms (docs and skeletons) #485

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ repos:
rev: 6.0.0
hooks:
- id: flake8
exclude: |
(?x)^(
docs/developer/approximate_inf_algorithm.py|
docs/developer/sampling_algorithm.py
)$
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
Expand All @@ -33,6 +38,11 @@ repos:
rev: v1.0.1
hooks:
- id: mypy
exclude: |
(?x)^(
docs/developer/approximate_inf_algorithm.py|
docs/developer/sampling_algorithm.py
)$
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.3.1
hooks:
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ information related to the transition are returned separately. They can thus be
easily composed and exchanged. We specialize these kernels by closure instead of
passing parameters.

## New algorithms

We want to make implementing and testing new algorithms easy with BlackJAX. You can test new algorithms by reusing the basic components of the many known methods already implemented in the library. Follow the [developer guidelines](https://blackjax-devs.github.io/blackjax/developer/guidelines.html) to implement your method and test new ideas on existing methods without writing everything from scratch.

## Contributions

Please follow our [short guide](https://github.com/blackjax-devs/blackjax/blob/main/CONTRIBUTING.md).
Expand Down
137 changes: 137 additions & 0 deletions docs/developer/approximate_inf_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple

import jax
from optax import GradientTransformation

# import basic components that are already implemented
# or that you have implemented with a general structure
from blackjax.base import VIAlgorithm
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = [
"ApproxInfState",
"ApproxInfInfo",
"init",
"sample",
"step",
"approx_inf_algorithm",
]


class ApproxInfState(NamedTuple):
    """Container for the state of your approximate inference algorithm.

    Describe here every variable that must be carried from one step to the
    next and that is needed to sample from the approximation.
    """

    ...


class ApproxInfInfo(NamedTuple):
    """Side information about one transition of your algorithm.

    Describe here each diagnostic value collected at every step of the
    approximation.
    """

    ...


def init(position: ArrayLikeTree, logdensity_fn: Callable, *args, **kwargs):
    # Build the initial state of the algorithm from the starting position.
    return ApproxInfState(...)


def step(
    rng_key: PRNGKey,
    state: ApproxInfState,  # was mistyped as ApproxInfInfo; the step consumes a state
    logdensity_fn: Callable,
    optimizer: GradientTransformation,
    *args,
    **kwargs,
) -> Tuple[ApproxInfState, ApproxInfInfo]:
    """Approximate the target density using your approximation.

    Parameters
    ----------
    List and describe its parameters.
    """
    # extract the previous parameters from the state
    params = ...
    # generate pseudorandom keys
    key_other, key_update = jax.random.split(rng_key, 2)
    # update the parameters and build a new state
    new_state = ApproxInfState(...)
    info = ApproxInfInfo(...)

    return new_state, info


def sample(rng_key: PRNGKey, state: ApproxInfState, num_samples: int = 1):
    """Draw samples from your approximation.

    Each sample is a PyTree with the same structure as the ``position``
    argument of the ``init`` function.
    """
    samples = ...
    return samples


class approx_inf_algorithm:
    """Implements the (basic) user interface for your approximate inference method.

    Describe in detail the inner mechanism of the method and its use.

    Example
    -------
    Illustrate the use of the algorithm.

    Parameters
    ----------
    List and describe its parameters.

    Returns
    -------
    A ``VIAlgorithm``.
    """

    init = staticmethod(init)
    step = staticmethod(step)
    sample = staticmethod(sample)

    def __new__(  # type: ignore[misc]
        cls,
        logdensity_fn: Callable,
        optimizer: GradientTransformation,
        *args,
        **kwargs,
    ) -> VIAlgorithm:
        def init_fn(position: ArrayLikeTree):
            # `init` takes the log-density function as its second argument
            # (see its signature above); the original skeleton passed the
            # optimizer here by mistake.
            return cls.init(position, logdensity_fn, ...)

        def step_fn(rng_key: PRNGKey, state):
            return cls.step(
                rng_key,
                state,
                logdensity_fn,
                optimizer,
                ...,
            )

        def sample_fn(rng_key: PRNGKey, state, num_samples):
            return cls.sample(rng_key, state, num_samples)

        return VIAlgorithm(init_fn, step_fn, sample_fn)


# other functions that help make `init`, `step` and/or `sample` easier to read and understand
43 changes: 43 additions & 0 deletions docs/developer/guidelines.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Developer Guidelines

In the following section, we’ll explain BlackJAX’s design of different algorithms for Monte Carlo integration. Keep in mind some basic principles:

- Leverage JAX's unique strengths: functional programming and composable function-transformation approach.
- Write small and general functions, compose them to create complex methods, and reuse the same building blocks for similar algorithms.
- Consider compatibility with the broader JAX ecosystem (Flax, Optax, GPJax).
- Write code that is easy to read and understand.
- Write well-documented code describing in detail the inner mechanism of the algorithm and its use.

## Core implementation
BlackJAX supports sampling algorithms such as Markov Chain Monte Carlo (MCMC), Sequential Monte Carlo (SMC), Stochastic Gradient MCMC (SGMCMC), and approximate inference algorithms such as Variational Inference (VI). In all cases, BlackJAX takes a Markovian approach, whereby its current state contains all the information to obtain the next iteration of an algorithm. This naturally results in a functionally pure structure, where no side-effects are allowed, simplifying parallelisation. Additionally, BlackJAX supports adaptation algorithms that efficiently tune the hyperparameters of sampling algorithms, usually aimed at reducing autocorrelation between sequential samples.

The user-facing interface of a **sampling algorithm** is made up of an initializer and an iterator:
```python
# Generic sampling algorithm:
sampling_algorithm = blackjax.sampling_algorithm(logdensity_fn, *args, **kwargs)
state = sampling_algorithm.init(initial_position)
new_state, info = sampling_algorithm.step(rng_key, state)
```
Build from the basic skeleton of a sampling algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/sampling_algorithm.py). Only the `sampling_algorithm` class and the `init` and `build_kernel` functions need to be in the final version of your algorithm; the rest might be useful but are not necessary.

The user-facing interface of an **approximate inference algorithm** is made up of an initializer, iterator, and sampler:
```python
# Generic approximate inference algorithm:
approx_inf_algorithm = blackjax.approx_inf_algorithm(logdensity_fn, optimizer, *args, **kwargs)
state = approx_inf_algorithm.init(initial_position)
new_state, info = approx_inf_algorithm.step(rng_key, state)
position_samples = approx_inf_algorithm.sample(rng_key, state, num_samples)
```
Build from the basic skeleton of an approximate inference algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/approximate_inf_algorithm.py). Only the `approx_inf_algorithm` class and the `init`, `step` and `sample` functions need to be in the final version of your algorithm; the rest might be useful but are not necessary.

## Basic components
All inference algorithms are composed of basic components which provide the lowest level of algorithm abstraction and are available to the user. When implementing a new inference algorithm, you should first break it down into its basic components, then find and use those that are already implemented *before* writing your own. Consider, for example, the [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In BlackJAX, two basic components implement, respectively, a specific (but simpler) and a general version of this accept/reject step:

- Metropolis step: if the proposal transition kernel is symmetric, i.e. if the probability of going from the initial to the proposed position is always equal to the probability of going from the proposed to the initial position, the acceptance probability is calculated using `mcmc.proposal.safe_energy_diff`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. For instance, see `mcmc.hmc.hmc_proposal`.
- Metropolis-Hastings step: for the more general case of an asymmetric proposal transition kernel, the acceptance probability is calculated by creating a proposal using `mcmc.proposal.compute_asymmetric_acceptance_ratio`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. For instance, see `mcmc.mala.build_kernel`.

When implementing an algorithm you could choose to replace the reversible binomial sampling step with Neal's [non-reversible slice sampling](https://arxiv.org/abs/2001.11950) step by simply replacing `mcmc.proposal.static_binomial_sampling` with `mcmc.proposal.nonreversible_slice_sampling` on either of the previous implementations. Make sure to carry over to the next iteration an updated slice for the slice sampling step, instead of passing a pseudo-random number generating key!

The previous example illustrates the practicality of basic components: they avoid rewriting the same methods and allow you to easily test new algorithms that customize established ones, like how `blackjax.mcmc.ghmc` is `blackjax.mcmc.hmc` only with a persistent momentum and a non-reversible slice sampling step instead of the static binomial sampling step.

Because JAX operates on pure functions, inference algorithms always return a `typing.NamedTuple` containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of BlackJAX, so you must do it in a way that abstracts the uninteresting bookkeeping and allows access to important variables at each step. The algorithms should also return a `typing.NamedTuple` with important information about each iteration.
170 changes: 170 additions & 0 deletions docs/developer/sampling_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple

import jax

# import basic components that are already implemented
# or that you have implemented with a general structure
# for example, if you do a Metropolis-Hastings accept/reject step:
import blackjax.mcmc.proposal as proposal
from blackjax.base import SamplingAlgorithm
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = [
"SamplingAlgoState",
"SamplingAlgoInfo",
"init",
"build_kernel",
"sampling_algorithm",
]


class SamplingAlgoState(NamedTuple):
    """Container for the state of your sampling algorithm.

    Describe here every variable that must be carried from one iteration of
    the model to the next.
    """

    ...


class SamplingAlgoInfo(NamedTuple):
    """Side information about one transition of your algorithm.

    Describe here each diagnostic value collected at every iteration of the
    model.
    """

    ...


def init(position: ArrayLikeTree, logdensity_fn: Callable, *args, **kwargs):
    # Build the initial state of the algorithm from the starting position.
    return SamplingAlgoState(...)


def build_kernel(*args, **kwargs):
    """Build your kernel.

    Parameters
    ----------
    List and describe its parameters.

    Returns
    -------
    Describe the kernel that is returned.
    """

    def kernel(
        rng_key: PRNGKey,
        state: SamplingAlgoState,
        logdensity_fn: Callable,
        *args,
        **kwargs,
    ) -> Tuple[SamplingAlgoState, SamplingAlgoInfo]:
        """Generate a new sample with the sampling kernel."""

        # build everything you'll need
        proposal_generator = sampling_algorithm_proposal(...)

        # generate pseudorandom keys
        key_other, key_proposal = jax.random.split(rng_key, 2)

        # generate the proposal with all its parts
        proposal, info = proposal_generator(key_proposal, ...)
        proposal = SamplingAlgoState(...)

        return proposal, info

    return kernel


class sampling_algorithm:
    """Implements the (basic) user interface for your sampling kernel.

    Describe in detail the inner mechanism of the algorithm and its use.

    Example
    -------
    Illustrate the use of the algorithm.

    Parameters
    ----------
    List and describe its parameters.

    Returns
    -------
    A ``SamplingAlgorithm``.
    """

    init = staticmethod(init)
    build_kernel = staticmethod(build_kernel)

    def __new__(  # type: ignore[misc]
        cls,
        logdensity_fn: Callable,
        *args,
        **kwargs,
    ) -> SamplingAlgorithm:
        kernel = cls.build_kernel(...)

        def init_fn(position: ArrayLikeTree):
            return cls.init(position, logdensity_fn, ...)

        def step_fn(rng_key: PRNGKey, state):
            return kernel(
                rng_key,
                state,
                logdensity_fn,
                ...,
            )

        return SamplingAlgorithm(init_fn, step_fn)


# and other functions that help make `init` and/or `build_kernel` easier to read and understand
def sampling_algorithm_proposal(*args, **kwargs) -> Callable:
    """Title

    Description

    Parameters
    ----------
    List and describe its parameters.

    Returns
    -------
    Describe what is returned.
    """
    # as an example, a Metropolis-Hastings step with a symmetric transition
    # kernel would look like this:
    acceptance_ratio = proposal.safe_energy_diff
    sample_proposal = proposal.static_binomial_sampling

    def generate(rng_key, state):
        # propose a new sample
        proposal_state = ...

        # accept or reject the proposed sample
        initial_energy = ...
        proposal_energy = ...
        new_proposal, is_diverging = acceptance_ratio(initial_energy, proposal_energy)
        # NOTE: the original skeleton passed the imported `proposal` *module*
        # here by mistake; pass the proposal built from the current state.
        sampled_state, info = sample_proposal(rng_key, proposal_state, new_proposal)

        # maybe add to the returned state and collect more useful information
        sampled_state, info = ...

        return sampled_state, info

    return generate
Loading
Loading