A NumPyro implementation of AutoStep methods
pip install "autostep @ git+ssh://git@github.com/UBC-Stat-ML/autostep.git"
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC
from autostep.autohmc import AutoMALA
from autostep import utils
# Eight schools data: observed per-school effects `y` and the known
# observation standard deviations `sigma` (the scale of the 'obs' likelihood).
y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
# Number of schools, derived from the data rather than hard-coded so the
# plate size used by the model can never disagree with the arrays above.
J = len(y)
def eight_schools(J, sigma, y=None):
    """Hierarchical (centered) "eight schools" model.

    Args:
        J: number of schools; the size of the 'J' plate.
        sigma: per-school observation standard deviations used as the
            scale of the 'obs' likelihood.
        y: observed per-school effects passed as ``obs`` to the 'obs' site;
            when None, that site is left unobserved.
    """
    # Population-level location and spread shared by all schools.
    mu = numpyro.sample('mu', dist.Normal(loc=0, scale=5))
    tau = numpyro.sample('tau', dist.HalfCauchy(scale=5))
    with numpyro.plate('J', size=J):
        # One latent effect per school, drawn from the population distribution,
        # then observed with school-specific noise `sigma`.
        theta = numpyro.sample('theta', dist.Normal(loc=mu, scale=tau))
        numpyro.sample('obs', dist.Normal(loc=theta, scale=sigma), obs=y)
# Instantiate the sampler and run it.
n_rounds = 16
# Translate the number of rounds into warmup / kept sample counts for MCMC.
n_warmup, n_keep = utils.split_n_rounds(n_rounds)
# AutoMALA with its defaults: symmetric selector and a (log-)random mix
# preconditioner.
kernel = AutoMALA(eight_schools)
mcmc = MCMC(
    kernel,
    num_warmup=n_warmup,
    num_samples=n_keep,
)
# Model arguments after the PRNG key are forwarded to `eight_schools`.
mcmc.run(random.key(9), J, sigma, y=y)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mu 6.55 3.00 6.61 2.08 12.07 13.91 1.12
tau 2.95 2.51 2.31 0.04 6.30 37.44 1.03
theta[0] 7.56 4.39 7.44 -0.89 14.19 14.30 1.20
theta[1] 7.28 3.96 7.12 0.34 13.56 33.12 1.12
theta[2] 6.71 3.93 6.47 0.07 13.04 45.15 1.03
theta[3] 5.83 4.72 6.52 -1.56 12.87 20.07 1.14
theta[4] 6.16 4.03 6.46 -0.71 12.64 30.49 1.05
theta[5] 7.23 4.42 6.93 -1.63 13.64 25.33 1.06
theta[6] 7.63 3.65 7.45 1.46 13.59 23.19 1.13
theta[7] 7.28 3.86 7.23 0.60 12.96 34.00 1.09
- Jittered step sizes
Biron-Lattes, M., Surjanovic, N., Syed, S., Campbell, T., & Bouchard-Côté, A. (2024). autoMALA: Locally adaptive Metropolis-adjusted Langevin algorithm. Proceedings of The 27th International Conference on Artificial Intelligence and Statistics, in Proceedings of Machine Learning Research 238:4600-4608.
Liu, T., Surjanovic, N., Biron-Lattes, M., Bouchard-Côté, A., & Campbell, T. (2024). AutoStep: Locally adaptive involutive MCMC. arXiv preprint arXiv:2410.18929.