
Python implementation of the No-U-Turn Sampler leveraging JAX

guillaume-plc/jaxnuts

jaxnuts

Python implementation of the No-U-Turn Sampler (NUTS) with dual averaging from Hoffman and Gelman (2014, Algorithm 6), leveraging JAX.
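Algorithm 6 builds each trajectory from repeated leapfrog steps and stops doubling it once the no-U-turn criterion fails. The sketch below shows only those two building blocks, assuming the paper's identity mass matrix; `leapfrog` and `no_u_turn` are hypothetical names chosen for illustration, not functions exported by `jaxnuts`.

```python
import jax
import jax.numpy as jnp


def leapfrog(theta, r, grad_logp, eps):
    """One leapfrog step for H(theta, r) = -logp(theta) + r.r / 2 (identity mass matrix)."""
    r = r + 0.5 * eps * grad_logp(theta)   # half step for the momentum
    theta = theta + eps * r                # full step for the position
    r = r + 0.5 * eps * grad_logp(theta)   # second half step for the momentum
    return theta, r


def no_u_turn(theta_minus, theta_plus, r_minus, r_plus):
    """True while neither end of the trajectory is moving back toward the other."""
    dtheta = theta_plus - theta_minus
    return (jnp.dot(dtheta, r_minus) >= 0) & (jnp.dot(dtheta, r_plus) >= 0)


# Example on the standard normal used in the usage section below
grad_logp = jax.grad(lambda x: -0.5 * jnp.dot(x, x))
theta0, r0 = jnp.ones(2), jnp.array([1.0, 0.0])
theta1, r1 = leapfrog(theta0, r0, grad_logp, 0.1)
print(no_u_turn(theta0, theta1, r0, r1))
```

In the full sampler these pieces are combined with a slice variable and recursive tree doubling; the leapfrog integrator approximately conserves the Hamiltonian, which is what keeps acceptance high for small step sizes.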

Usage

Import libraries

import jax
import jax.numpy as jnp
import jax.random as random

from jaxnuts.sampler import NUTS

For low-dimensional problems such as this simple example, force JAX to use the CPU to avoid GPU overhead:

jax.config.update('jax_platform_name', 'cpu')
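To check that the setting took effect, the standard `jax.default_backend()` and `jax.devices()` functions report which platform JAX will use. A minimal check:

```python
import jax

# Same setting as above: prefer the CPU backend
jax.config.update('jax_platform_name', 'cpu')

# Name of the platform used for new computations, e.g. 'cpu' or 'gpu'
print(jax.default_backend())
# Platform of each device JAX exposes on the default backend
print([d.platform for d in jax.devices()])
```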

Define the log-density to sample from (it only needs to be known up to an additive constant)

def logprob(x):
  """Standard normal log-density, up to an additive constant"""
  return -.5 * jnp.dot(x, x)
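NUTS simulates Hamiltonian dynamics, so it needs the gradient of `logprob` as well as its value. JAX can derive both automatically. The sketch below is an illustration only, not the `jaxnuts` internals: it checks `jax.grad` against the analytic gradient -x and uses `jax.vmap` to evaluate the log-density on a batch of points.

```python
import jax
import jax.numpy as jnp


def logprob(x):
    """Standard normal log-density, up to an additive constant."""
    return -.5 * jnp.dot(x, x)


# Gradient of the log-density, compiled with jit; analytically it equals -x
grad_logprob = jax.jit(jax.grad(logprob))

# Vectorized log-density: maps logprob over the leading axis of a batch
batched_logprob = jax.vmap(logprob)

x = jnp.array([1.0, -2.0])
print(grad_logprob(x))                                # analytically -x = [-1, 2]
print(batched_logprob(jnp.stack([x, jnp.zeros(2)])))  # analytically [-2.5, 0]
```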

Generate samples

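key = random.PRNGKey(0)
# Initial position (1, 1); target_acceptance and M_adapt control step-size adaptation
sampler = NUTS(jnp.ones(2), logp=logprob, target_acceptance=.5, M_adapt=1000)
# Draw 1000 samples; an updated PRNG key and the step size are returned alongside them
key, samples, step_size = sampler.sample(1000, key)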
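The names `target_acceptance` and `M_adapt` suggest the target acceptance rate δ and the number of adaptation iterations M^adapt of the dual-averaging step-size scheme used by Algorithm 6; this mapping is an assumption based on the parameter names. Below is a minimal sketch of that update with the constants recommended in the paper (γ = 0.05, t0 = 10, κ = 0.75, μ = log(10 ε0)). `DualAveragingState`, `init_dual_averaging` and `update_dual_averaging` are hypothetical names, and the acceptance model in the loop is a toy stand-in for the statistic NUTS actually computes.

```python
from typing import NamedTuple

import jax.numpy as jnp


class DualAveragingState(NamedTuple):
    m: int          # number of adaptation updates performed so far
    log_eps: float  # log step size for the next iteration
    log_eps_bar: float  # weighted average of log step sizes (used after adaptation)
    h_bar: float    # running average of (delta - acceptance statistic)
    mu: float       # shrinkage target log(10 * eps0)


def init_dual_averaging(eps0):
    # eps_bar_0 = 1 and H_bar_0 = 0, as in Hoffman and Gelman
    return DualAveragingState(0, jnp.log(eps0), 0.0, 0.0, jnp.log(10.0 * eps0))


def update_dual_averaging(state, accept_stat, delta, gamma=0.05, t0=10.0, kappa=0.75):
    """One dual-averaging step-size update given this iteration's acceptance statistic."""
    m = state.m + 1
    w = 1.0 / (m + t0)
    h_bar = (1.0 - w) * state.h_bar + w * (delta - accept_stat)
    log_eps = state.mu - jnp.sqrt(m) / gamma * h_bar
    eta = m ** (-kappa)
    log_eps_bar = eta * log_eps + (1.0 - eta) * state.log_eps_bar
    return DualAveragingState(m, log_eps, log_eps_bar, h_bar, state.mu)


# Toy acceptance model (assumption): acceptance decays as exp(-eps)
state = init_dual_averaging(1.0)
for _ in range(1000):
    accept_stat = jnp.exp(-jnp.exp(state.log_eps))
    state = update_dual_averaging(state, accept_stat, delta=0.5)

# Should settle near log(2) ≈ 0.693, where exp(-eps) = 0.5
print(jnp.exp(state.log_eps_bar))
```

In the paper, once the M^adapt adaptation iterations are over, the step size is frozen at exp(log_eps_bar) for the remaining draws.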
