This repo contains code supporting a series of blog posts I'm currently writing. Start at Part 1: the basics.
This repo contains code for MCMC-based fully Bayesian inference for a logistic regression model in R, Python, Scala, Haskell, Dex, and C, using both bespoke hand-coded samplers (random walk Metropolis, the unadjusted Langevin algorithm, MALA, and HMC) and samplers constructed with the help of libraries such as JAGS, Stan, JAX, BlackJAX, NumPyro, PyMC3, and Spark.
I intend to add similar examples using one or two other libraries. At some point I'd also like to switch to a much bigger dataset that better illustrates some of the scalability issues of the different languages and libraries.
Here we will conduct fully Bayesian inference for the standard Bayesian logistic regression model for a binary outcome based on some covariates. The $i$th observation, $y_i \in \{0, 1\}$, will be 1 with probability $p_i$, where
$$\operatorname{logit}(p_i) = x_i^\top \beta,$$
with $x_i$ the vector of covariates for observation $i$ and $\beta$ the vector of regression parameters. This leads to the log-likelihood
$$\ell(\beta; y) = \sum_{i=1}^n \left[ y_i\, x_i^\top\beta - \log\!\left(1 + \exp(x_i^\top\beta)\right) \right].$$
JAX can auto-diff log-likelihoods like this, but for comparison purposes we can also use the hard-coded gradient for MALA and HMC:
$$\nabla_\beta\, \ell(\beta; y) = X^\top (y - p), \qquad p_i = \operatorname{expit}(x_i^\top\beta),$$
where $X$ is the matrix with rows $x_i^\top$.
For a fully Bayesian analysis, we also need a prior distribution. Here we will assume independent normal priors on the elements of $\beta$.
We will be analysing the "Pima" training dataset, with 200 observations and 7 predictors. Including an intercept as the first covariate gives a parameter vector $\beta$ of length $8$.
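To make the pieces above concrete, here is a minimal NumPy sketch of the log-posterior and its hard-coded gradient for this model. It is not one of the repo's scripts; the function names and the prior scale `pscale` are illustrative assumptions.

```python
import numpy as np
from scipy.special import expit  # the inverse-logit function

def make_log_post(X, y, pscale=10.0):
    """Return the log-posterior and its gradient for Bayesian logistic
    regression with independent N(0, pscale^2) priors on each element of
    beta. The prior scale here is illustrative, not the repo's choice."""
    def log_post(beta):
        eta = X @ beta
        ll = np.sum(y * eta - np.log1p(np.exp(eta)))   # Bernoulli-logit log-likelihood
        lprior = -0.5 * np.sum(beta**2) / pscale**2    # normal log-prior (up to a constant)
        return ll + lprior

    def grad_log_post(beta):
        p = expit(X @ beta)
        return X.T @ (y - p) - beta / pscale**2        # X'(y - p) plus prior gradient

    return log_post, grad_log_post
```

These two functions are essentially all that the random walk Metropolis, unadjusted Langevin, MALA, and HMC samplers below need from the model.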
Please read: The code in this repo should not be used for any kind of serious performance or benchmarking exercise. I have deliberately used a reasonably consistent, simple style of implementation across all of the languages, and have deliberately chosen not to optimise any of the implementations. Clearly all of them could be optimised, and the nature of the optimisation would differ greatly between languages. Moreover, benchmarking on a small toy dataset such as the one considered here would be very uninteresting: the interesting scaling issues only become apparent on larger datasets.
Note that these scripts use pacman to download and install any missing dependencies.
- create-dataset.R - we will use the infamous MASS::Pima.tr dataset, exported from R in parquet format (rather than CSV, since it's now the 21st Century), but also saved in a simple text format for languages that can't easily read parquet.
- fit-glm.R - kick off with a simple GLM fit in R for sanity-checking purposes.
- fit-bayes.R - MAP, followed by a random walk Metropolis MCMC sampler in R (a minimal sketch of this kernel appears after this list).
- fit-ul.R - Unadjusted Langevin in R (with a simple diagonal pre-conditioner). Note that this algorithm is approximate, so we wouldn't expect it to match up perfectly with the exact sampling methods.
- fit-mala.R - MALA in R (with a diagonal pre-conditioner).
- fit-hmc.R - HMC in R (with a diagonal mass-matrix).
- fit-rjags.R - Fit using rjags. Note that this script probably won't work unless a site-wide installation of JAGS is available.
- fit-rstan.R - Fit using rstan.
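The scripts above (and their counterparts in the other languages) share the same basic structure. As a language-agnostic reference, here is a minimal NumPy sketch of a random walk Metropolis sampler; the names and tuning values are illustrative, not those used in fit-bayes.R.

```python
import numpy as np

def rwmh(log_post, init, n_iters=10000, thin=10, prop_sd=0.02, rng=None):
    """Random walk Metropolis with a spherical Gaussian proposal.
    Returns an (n_iters, p) array of thinned samples."""
    rng = np.random.default_rng() if rng is None else rng
    beta = np.asarray(init, dtype=float)
    lp = log_post(beta)
    out = np.empty((n_iters, beta.size))
    for i in range(n_iters):
        for _ in range(thin):
            prop = beta + prop_sd * rng.standard_normal(beta.size)
            lp_prop = log_post(prop)
            # accept with probability min(1, exp(lp_prop - lp))
            if np.log(rng.uniform()) < lp_prop - lp:
                beta, lp = prop, lp_prop
        out[i] = beta
    return out
```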
These scripts assume a Python installation with NumPy and SciPy. The later scripts require JAX; the BlackJAX scripts require BlackJAX, the NumPyro script requires NumPyro, and the PyMC3 script requires PyMC3. These can be pip installed for basic use. See the websites for more detailed information.
- fit-numpy.py - Random walk MH with NumPy.
- fit-np-ul.py - Unadjusted Langevin with NumPy (approximate).
- fit-np-mala.py - MALA with NumPy.
- fit-np-hmc.py - HMC with NumPy.
- fit-jax.py - RW MH with the log posterior and MH kernel in JAX, but the main MCMC loop in Python.
- fit-jax2.py - As above, but with the main MCMC loop in JAX as well (much faster; see the lax.scan sketch after this list).
- fit-jax-ul.py - JAX for unadjusted Langevin (with a diagonal pre-conditioner). Note that this is an approximate algorithm. Note also that JAX AD is being used for gradients.
- fit-jax-mala.py - JAX for MALA (with a diagonal pre-conditioner). JAX AD for gradients.
- fit-jax-hmc.py - JAX for HMC (with a diagonal mass-matrix). JAX AD for gradients.
- fit-blackjax.py - RW MH using BlackJAX.
- fit-blackjax-mala.py - MALA with BlackJAX. Note that the MALA kernel in BlackJAX doesn't seem to allow a pre-conditioner, so a huge thinning interval is used here to get vaguely reasonable results.
- fit-blackjax-nuts.py - NUTS sampler from BlackJAX.
- fit-numpyro.py - NUTS sampler from NumPyro.
- fit-pymc3.py - NUTS sampler from (old) PyMC3.
- fit-pymc.py - NUTS sampler from (new) PyMC (i.e. PyMC version >= 4.0).
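The speed difference between fit-jax.py and fit-jax2.py comes from pushing the MCMC loop itself into JAX. A minimal sketch of that pattern using jax.lax.scan is shown below; the kernel here is a placeholder for shape-checking, not the repo's Metropolis kernel.

```python
import jax
import jax.numpy as jnp

def mcmc_scan(kernel, init_state, n_iters, key):
    """Iterate a pure transition kernel(key, state) -> state inside JAX,
    collecting every state, using lax.scan instead of a Python loop."""
    keys = jax.random.split(key, n_iters)

    def step(state, k):
        new_state = kernel(k, state)
        return new_state, new_state  # carry the state, and store it

    _, chain = jax.lax.scan(step, init_state, keys)
    return chain

# Placeholder kernel just to show the wiring (a real one would be an MH step).
kernel = lambda k, s: s + 0.1 * jax.random.normal(k, s.shape)
run = jax.jit(lambda key: mcmc_scan(kernel, jnp.zeros(8), 1000, key))
chain = run(jax.random.PRNGKey(42))   # chain.shape == (1000, 8)
```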
The Scala examples just require a recent JVM and sbt; sbt will look after other dependencies (including Scala itself). See the Readme in the Scala directory for further info.
- fit-bayes.scala - Random walk MH with Scala and Breeze.
- fit-nopar.scala - Random walk MH, re-factored to make it easy to run in parallel, but still serial.
- fit-par.scala - Random walk MH, running in parallel on all available CPU cores. Note that the evaluation of the log-likelihood is parallelised over observations, but due to the very small size of this dataset, this version runs considerably slower than the previous version. For large datasets it will be a different story.
- fit-ul.scala - Unadjusted Langevin with Breeze (approximate).
- fit-mala.scala - MALA with Breeze.
- fit-hmc.scala - HMC with Breeze.
The Spark example requires a Spark installation in addition to sbt. See the Readme in the Scala directory for further info.
- fit-spark.scala - RW MH, with Spark being used to distribute the log-likelihood evaluation over a cluster. Note that this code runs very slowly, as the overheads associated with distributing the computation dominate for very small datasets like the one used here. The thinning interval has been reduced so that the job completes in reasonable time.
The Haskell examples use stack to build, run, and manage dependencies. See the readme in the Haskell/lr directory for further details.
- Rwmh.hs - Random walk MH in Haskell, using a stateful monadic random number generator.
- RwmhP.hs - Random walk MH in Haskell, using a pure random number generator explicitly threaded through the code.
- RwmhPS.hs - Random walk MH in Haskell, using a pure random number generator together with a splitting approach, à la JAX and Dex (see the JAX illustration after this list).
- Mala.hs - MALA in Haskell (using a stateful monadic random number generator).
- Hmc.hs - HMC in Haskell (using a stateful monadic random number generator).
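The splitting approach used by RwmhPS.hs is the same idea JAX uses for its functional PRNG. A tiny JAX illustration (not the Haskell code) of how keys are split so that every step gets its own independent stream:

```python
import jax

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)       # split off a key for one draw
noise = jax.random.normal(subkey, (8,))   # use the subkey; keep `key` for later

# Or split once up front, one independent key per MCMC iteration:
step_keys = jax.random.split(key, 1000)
```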
The Dex examples rely only on a basic Dex installation. See the readme in the Dex directory for further details. Note that Dex is an early-stage research project lacking many of the tools and libraries one would normally expect, and it is also rather lacking in documentation. However, it's an interesting language: purely functional, strongly typed, differentiable, and fast.
- fit-bayes.dx - Random walk MH in Dex. Dex uses a splittable random number generator, similar to JAX. It's not quite as fast as JAX, but faster than anything else I've tried, including my C code.
- fit-ul.dx - Unadjusted Langevin in Dex (approximate), with hard-coded gradients.
- fit-ul-ad.dx - Unadjusted Langevin in Dex (approximate), with auto-differentiated gradients.
- fit-mala.dx - MALA in Dex, with hard-coded gradients.
- fit-mala-ad.dx - MALA in Dex, with auto-differentiated gradients. Roughly two to three times slower than using hard-coded gradients, which seems reasonable.
- fit-hmc.dx - HMC in Dex, with hard-coded gradients.
- fit-hmc-ad.dx - HMC in Dex, with auto-differentiated gradients. Again, 2-3 times slower than using hard-coded gradients. But still very fast.
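The -ad variants above trade some speed for the convenience of auto-differentiated gradients. A quick, self-contained way to sanity-check a hard-coded gradient against AD is sketched below in JAX, using synthetic data rather than the Pima dataset:

```python
import jax
import jax.numpy as jnp

def log_lik(beta, X, y):
    eta = X @ beta
    return jnp.sum(y * eta - jnp.log1p(jnp.exp(eta)))  # Bernoulli-logit log-likelihood

def grad_log_lik(beta, X, y):
    p = jax.nn.sigmoid(X @ beta)                       # hard-coded gradient: X'(y - p)
    return X.T @ (y - p)

# Synthetic data of the same shape as the Pima example (200 x 8).
X = jax.random.normal(jax.random.PRNGKey(1), (200, 8))
y = (jax.random.uniform(jax.random.PRNGKey(2), (200,)) < 0.5).astype(jnp.float32)
beta = 0.1 * jnp.ones(8)

auto = jax.grad(log_lik)(beta, X, y)                   # AD gradient
hand = grad_log_lik(beta, X, y)
print(jnp.max(jnp.abs(auto - hand)))                   # ~0 up to floating-point error
```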
The C examples assume a Unix-like development environment. See the Readme in the C directory for further info.
- fit-bayes.c - Random walk MH with C and the GSL. The code isn't pretty, but it's fast (in particular, there are no allocations in the main MCMC loop). But still not as fast as JAX, even on a single core.
Copyright (C) 2022, Darren J Wilkinson, but released under a GPL-3.0 license