Skip to content

Commit

Permalink
Merge remote to fork (#1)
Browse files Browse the repository at this point in the history
* Update README.md (blackjax-devs#638)

* Update README.md

Update citation.

* Update README.md

* Indexing the notebook showing how to reproduce the GIF. (blackjax-devs#640)

Co-authored-by: Junpeng Lao <[email protected]>

* Bump python version (blackjax-devs#645)

* Bump python version

* update bool inverse

* SMC: allow each mutation kernel to have different parameters. (blackjax-devs#649)

* vmaping over parameters in base

* switch from mcmc_factory to just passing in parameters

* pre-commit and typing

* CRU and docs improvement

* pre-commit

* code review updates

* pre-commit

* rename test

* Migrate from deprecated `host_callback` to `io_callback` (blackjax-devs#651)

* Migrate from deprecated `host_callback` to `io_callback`

Co-Authored-By:
George Necula <[email protected]>

* Format file

* Fix bug

* Fix MALA transition energy (blackjax-devs#653)

* Fix MALA transition energy

* Use a different logic.

* Change variable names (blackjax-devs#654)

* Replace iterative RNG split and carry with `jax.random.fold_in` (blackjax-devs#656)

* Replace iterative RNG split and carry with `jax.random.fold_in`

* revert unintended change

* file formatting

* change `jax.tree_map` to `jax.tree.map`

* revert unintended file

* fiddle with rng_key

* seed again

* Removal of Algorithm classes. (blackjax-devs#657)

* more

* removing export

* removal of classes, tests passing

* linter

* fix on test

* linter

* removing parametrization on test

* code review updates

* exporting as_top_level_api in dynamic_hmc

* linter

* code review update: replace imports

* Fix deprecated call to jnp.clip (blackjax-devs#664)

* Update jax version requirements (blackjax-devs#666)

Fix blackjax-devs#665

* Make tests pass on `aarch64-linux` (blackjax-devs#671)

* Enable fitlering of AdaptationInfo (blackjax-devs#674)

* enable AdaptationInfo filtering

* revert progress_bar

* fix pre-commit

* fix empty sets

* enable adapt info filtering for all adaptation algorithms

* fix precommit /progressbar=True

* change filter tuple to use tree_map

* Update `run_inference_algorithm` to split `initial_position` and `initial_state` (blackjax-devs#672)

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* ADD INITIAL_POSITION

* FIX TEST

* RENAME O

* FIX DOCSTRING

* PUT EXPECTATION AFTER TRANSFORM

* Preconditioned mclmc (blackjax-devs#673)

* TESTS

* TESTS

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* ADD PRECONDITIONING TO MCLMC

* ADD PRECONDITIONING TO TUNING FOR MCLMC

* UPDATE GITIGNORE

* UPDATE GITIGNORE

* UPDATE TESTS

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* STREAMING AVERAGE

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* GITIGNORE

* PRECOMMIT CLEAN UP

* ADD INITIAL_POSITION

* FIX TEST

* ADD TEST

* REMOVE BENCHMARKS

* BUG FIX

* CHANGE PRECISION

* CHANGE PRECISION

* RENAME O

* UPDATE STREAMING AVG

* UPDATE PR

* RENAME STD_MAT

* New integrator, and add some metadata to integrators.py (blackjax-devs#681)

* TESTS

* TESTS

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* ADD PRECONDITIONING TO MCLMC

* ADD PRECONDITIONING TO TUNING FOR MCLMC

* UPDATE GITIGNORE

* UPDATE GITIGNORE

* UPDATE TESTS

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* STREAMING AVERAGE

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* GITIGNORE

* PRECOMMIT CLEAN UP

* FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS

* TEMPORARILY ADD BENCHMARKS

* ADD INITIAL_POSITION

* FIX TEST

* CLEAN UP

* REMOVE BENCHMARKS

* ADD TEST

* REMOVE BENCHMARKS

* BUG FIX

* CHANGE PRECISION

* CHANGE PRECISION

* ADD OMELYAN TEST

* RENAME O

* UPDATE STREAMING AVG

* UPDATE PR

* RENAME STD_MAT

* MERGE MAIN

* REMOVE COEFFICIENT EXPORTS

* Minor formatting (blackjax-devs#685)

* Minor formatting

* formatting

* fix test

* formatting

* MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT (blackjax-devs#687)

* FIX KWARG BUG (blackjax-devs#686)

* FIX KWARG BUG

* FIX KWARG BUG

* Change isokinetic_integrator generation API (blackjax-devs#689)

* Apply function on pytree directly. (blackjax-devs#692)

* Apply function on pytree directly.

Avoiding unnecssary unpacking

* Fix kwarg

* Fix sampling test. (blackjax-devs#693)

* Enable shared mcmc parameters with tempered smc (blackjax-devs#694)

* add parameter filtering

* fix parameter split + docstring

* change extend_paramss

* convert to bit twiddling (blackjax-devs#696)

* Remove nightly release (blackjax-devs#699)

* Fix doc mistakes (blackjax-devs#701)

* Fix equation formatting

* Clarify JAX gradient error

* Fix punctuation + capitalization

* Fix grammar

Should not begin sentence with "i.e." in English.

* Fix math formatting error

* Fix typo

Change parallel _ensample_ chain adaptation to parallel _ensemble_ chain adaptation.

* Add SVGD citation to appear in doc

Currently the SVGD paper is only cited in the `kernel` function, which is defined _within_ the `build_kernel` function. Because of this nested function format, the SVGD paper is _not_ cited in the documentation.

To fix this, I added a citation to the SVGD paper in the `as_top_level_api` docstring.

* Fix grammar + clarify doc

* Fix typo

---------

Co-authored-by: Junpeng Lao <[email protected]>

* Update index.md (blackjax-devs#711)

The jitted step remained unused, leading to the example running with an uncompiled nuts.step. 

Changing this reduces the execution time by a factor of 30 on my system and showcases blackjax' speed.

* Enable progress bar under pmap (blackjax-devs#712)

* enable pmap progbar

* fix bar creation

* add locking

* fix formatting

* switch to using chain state

* remove labels (blackjax-devs#716)

* Simplify `run_inference_algorithm` (blackjax-devs#714)

* fix minor type errors

* storing only expectation values

* fixed memory efficient sampling

* clean up

* renaming vars

* precommit fixes

* fixing tests

* fixing tests

* fixing tests

* fixing tests

* fixing tests

* merge main

* burn in and fix tests

* burn in and fix tests

* minor fixes

* minor fixes

* minor fixes

---------

Co-authored-by: [email protected] <[email protected]>

* Harmonize Quickstart example (blackjax-devs#717)

* Update README.md (blackjax-devs#719)

---------

Co-authored-by: Junpeng Lao <[email protected]>
Co-authored-by: Carlos Iguaran <[email protected]>
Co-authored-by: ksnxr <[email protected]>
Co-authored-by: Gaétan Lepage <[email protected]>
Co-authored-by: Alberto Cabezas <[email protected]>
Co-authored-by: andrewdipper <[email protected]>
Co-authored-by: Reuben <[email protected]>
Co-authored-by: Gilad Turok <[email protected]>
Co-authored-by: johannahaffner <[email protected]>
Co-authored-by: [email protected] <[email protected]>
  • Loading branch information
11 people authored Aug 14, 2024
1 parent 2e7f024 commit bbe5c15
Show file tree
Hide file tree
Showing 67 changed files with 2,172 additions and 1,585 deletions.
48 changes: 0 additions & 48 deletions .github/workflows/nightly.yml

This file was deleted.

4 changes: 2 additions & 2 deletions .github/workflows/publish_documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ jobs:
with:
persist-credentials: false

- name: Set up Python 3.9
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11

- name: Build the documentation with Sphinx
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11
- name: Build sdist and wheel
run: |
python -m pip install -U pip
Expand Down Expand Up @@ -51,7 +51,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11
- name: Give PyPI some time to update the index
run: sleep 240
- name: Attempt install from PyPI
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11
- uses: pre-commit/[email protected]

test:
Expand All @@ -24,7 +24,7 @@ jobs:
- style
strategy:
matrix:
python-version: [ '3.9', '3.11']
python-version: ['3.11', '3.12']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
26 changes: 11 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ or via conda-forge:
conda install -c conda-forge blackjax
```

Nightly builds (bleeding edge) of Blackjax can also be installed using `pip`:

```bash
pip install blackjax-nightly
```

BlackJAX is written in pure Python but depends on XLA via JAX. By default, the
version of JAX that will be installed along with BlackJAX will make your code
run on CPU only. **If you want to use BlackJAX on GPU/TPU** we recommend you follow
Expand Down Expand Up @@ -81,9 +75,10 @@ state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.key(0)
for _ in range(100):
rng_key, nuts_key = jax.random.split(rng_key)
state, _ = nuts.step(nuts_key, state)
step = jax.jit(nuts.step)
for i in range(100):
nuts_key = jax.random.fold_in(rng_key, i)
state, _ = step(nuts_key, state)
```

See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.
Expand Down Expand Up @@ -138,12 +133,13 @@ Please follow our [short guide](https://github.com/blackjax-devs/blackjax/blob/m
To cite this repository:

```
@software{blackjax2020github,
author = {Cabezas, Alberto, Lao, Junpeng, and Louf, R\'emi},
title = {{B}lackjax: A sampling library for {JAX}},
url = {http://github.com/blackjax-devs/blackjax},
version = {<insert current release tag>},
year = {2023},
@misc{cabezas2024blackjax,
title={BlackJAX: Composable {B}ayesian inference in {JAX}},
author={Alberto Cabezas and Adrien Corenflos and Junpeng Lao and Rémi Louf},
year={2024},
eprint={2402.10797},
archivePrefix={arXiv},
primaryClass={cs.MS}
}
```
In the above bibtex entry, names are in alphabetical order, the version number should be the last tag on the `main` branch.
Expand Down
188 changes: 140 additions & 48 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,163 @@
import dataclasses
from typing import Callable

from blackjax._version import __version__

from .adaptation.chees_adaptation import chees_adaptation
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
from .adaptation.meads_adaptation import meads_adaptation
from .adaptation.pathfinder_adaptation import pathfinder_adaptation
from .adaptation.window_adaptation import window_adaptation
from .base import SamplingAlgorithm, VIAlgorithm
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .mcmc.barker import barker_proposal
from .mcmc.dynamic_hmc import dynamic_hmc
from .mcmc.elliptical_slice import elliptical_slice
from .mcmc.ghmc import ghmc
from .mcmc.hmc import hmc
from .mcmc.mala import mala
from .mcmc.marginal_latent_gaussian import mgrad_gaussian
from .mcmc.mclmc import mclmc
from .mcmc.nuts import nuts
from .mcmc.periodic_orbital import orbital_hmc
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
from .mcmc.rmhmc import rmhmc
from .mcmc import barker
from .mcmc import dynamic_hmc as _dynamic_hmc
from .mcmc import elliptical_slice as _elliptical_slice
from .mcmc import ghmc as _ghmc
from .mcmc import hmc as _hmc
from .mcmc import mala as _mala
from .mcmc import marginal_latent_gaussian
from .mcmc import mclmc as _mclmc
from .mcmc import nuts as _nuts
from .mcmc import periodic_orbital, random_walk
from .mcmc import rmhmc as _rmhmc
from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk
from .mcmc.random_walk import (
irmh_as_top_level_api,
normal_random_walk,
rmh_as_top_level_api,
)
from .optimizers import dual_averaging, lbfgs
from .sgmcmc.csgld import csgld
from .sgmcmc.sghmc import sghmc
from .sgmcmc.sgld import sgld
from .sgmcmc.sgnht import sgnht
from .smc.adaptive_tempered import adaptive_tempered_smc
from .smc.inner_kernel_tuning import inner_kernel_tuning
from .smc.tempered import tempered_smc
from .vi.meanfield_vi import meanfield_vi
from .vi.pathfinder import pathfinder
from .vi.schrodinger_follmer import schrodinger_follmer
from .vi.svgd import svgd
from .sgmcmc import csgld as _csgld
from .sgmcmc import sghmc as _sghmc
from .sgmcmc import sgld as _sgld
from .sgmcmc import sgnht as _sgnht
from .smc import adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import tempered
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
from .vi import schrodinger_follmer as _schrodinger_follmer
from .vi import svgd as _svgd
from .vi.pathfinder import PathFinderAlgorithm

"""
The above three classes exist as a backwards compatible way of exposing both the high level, differentiable
factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower
level to be mostly functional programming in nature and reducing boilerplate code.
"""


@dataclasses.dataclass
class GenerateSamplingAPI:
differentiable: Callable
init: Callable
build_kernel: Callable

def __call__(self, *args, **kwargs) -> SamplingAlgorithm:
return self.differentiable(*args, **kwargs)

def register_factory(self, name, callable):
setattr(self, name, callable)


@dataclasses.dataclass
class GenerateVariationalAPI:
differentiable: Callable
init: Callable
step: Callable
sample: Callable

def __call__(self, *args, **kwargs) -> VIAlgorithm:
return self.differentiable(*args, **kwargs)


@dataclasses.dataclass
class GeneratePathfinderAPI:
differentiable: Callable
approximate: Callable
sample: Callable

def __call__(self, *args, **kwargs) -> PathFinderAlgorithm:
return self.differentiable(*args, **kwargs)


def generate_top_level_api_from(module):
return GenerateSamplingAPI(
module.as_top_level_api, module.init, module.build_kernel
)


# MCMC
hmc = generate_top_level_api_from(_hmc)
nuts = generate_top_level_api_from(_nuts)
rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh)
irmh = GenerateSamplingAPI(
irmh_as_top_level_api, random_walk.init, random_walk.build_irmh
)
dynamic_hmc = generate_top_level_api_from(_dynamic_hmc)
rmhmc = generate_top_level_api_from(_rmhmc)
mala = generate_top_level_api_from(_mala)
mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian)
orbital_hmc = generate_top_level_api_from(periodic_orbital)

additive_step_random_walk = GenerateSamplingAPI(
_additive_step_random_walk, random_walk.init, random_walk.build_additive_step
)

additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)

mclmc = generate_top_level_api_from(_mclmc)
elliptical_slice = generate_top_level_api_from(_elliptical_slice)
ghmc = generate_top_level_api_from(_ghmc)
barker_proposal = generate_top_level_api_from(barker)

hmc_family = [hmc, nuts]

# SMC
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)

smc_family = [tempered_smc, adaptive_tempered_smc]
"Step_fn returning state has a .particles attribute"

# stochastic gradient mcmc
sgld = generate_top_level_api_from(_sgld)
sghmc = generate_top_level_api_from(_sghmc)
sgnht = generate_top_level_api_from(_sgnht)
csgld = generate_top_level_api_from(_csgld)
svgd = generate_top_level_api_from(_svgd)

# variational inference
meanfield_vi = GenerateVariationalAPI(
_meanfield_vi.as_top_level_api,
_meanfield_vi.init,
_meanfield_vi.step,
_meanfield_vi.sample,
)
schrodinger_follmer = GenerateVariationalAPI(
_schrodinger_follmer.as_top_level_api,
_schrodinger_follmer.init,
_schrodinger_follmer.step,
_schrodinger_follmer.sample,
)

pathfinder = GeneratePathfinderAPI(
_pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample
)


__all__ = [
"__version__",
"dual_averaging", # optimizers
"lbfgs",
"hmc", # mcmc
"dynamic_hmc",
"rmhmc",
"mala",
"mgrad_gaussian",
"nuts",
"orbital_hmc",
"additive_step_random_walk",
"rmh",
"irmh",
"mclmc",
"elliptical_slice",
"ghmc",
"barker_proposal",
"sgld", # stochastic gradient mcmc
"sghmc",
"sgnht",
"csgld",
"window_adaptation", # mcmc adaptation
"meads_adaptation",
"chees_adaptation",
"pathfinder_adaptation",
"mclmc_find_L_and_step_size", # mclmc adaptation
"adaptive_tempered_smc", # smc
"tempered_smc",
"inner_kernel_tuning",
"meanfield_vi", # variational inference
"pathfinder",
"schrodinger_follmer",
"svgd",
"ess", # diagnostics
"rhat",
]
Loading

0 comments on commit bbe5c15

Please sign in to comment.