-
Hello all, thanks for writing this library, it is really great. I am trying to implement a Metropolis-within-Gibbs kind of sampler where I sample one set of variables using an elliptical slice sampler and the other set using an RMH sampler. I am not sure whether this can be done nicely, since as far as I can see the step function (or likelihood function) for each set of variables would need an argument for the other kernel's newly sampled positions (or something like that), correct? Is there a way this could be achieved idiomatically? Thanks!
-
You could partially apply the logprob_fn at each step so that each algorithm doesn't need to be aware of any parameters other than the ones for which it generates a new sample. If you give me a simple example I can maybe show you how to build a kernel that performs the two Gibbs sampling steps.
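For what it's worth, here is a minimal sketch of that partial-application idea in plain Python (no BlackJAX involved; logdensity, y_current and x_new are made-up names for illustration):
import functools

def logdensity(x, y):
    # hypothetical joint log-density of (x, y); stands in for the real model
    return -(x ** 2 + (y - x) ** 2)

# while updating x, hold y fixed at its current value
y_current = 1.0
logdensity_x = functools.partial(logdensity, y=y_current)

# while updating y, hold x fixed at its newly sampled value
x_new = 0.3
logdensity_y = functools.partial(logdensity, x_new)

# each kernel only ever sees a function of the block it is updating;
# both calls evaluate the same joint density at (x_new, y_current)
print(logdensity_x(x_new), logdensity_y(y_current))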
-
@rlouf could you please confirm that the code below is the intended implementation of your suggestion?
# example of a Gibbs kernel on two blocks of RMH kernels
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random
import blackjax
import pandas as pd
import seaborn as sns
def logprob_fn(x, y, Sigma):
"""
Log-pdf of (x, y) ~ Normal(0, Sigma)
"""
z = jnp.concatenate([x, y])
return jsp.stats.multivariate_normal.logpdf(z, mean=jnp.zeros_like(z), cov=Sigma)
# example with x.shape == y.shape == (2,)
Sigma = jnp.array([
[1., 0., .8, 0.],
[0., 1., 0., .8],
[.8, 0., 1., 0.],
[0., .8, 0., 1.]
])
def logprob(params):
return logprob_fn(Sigma=Sigma, **params)
# initial state
initial_position = {"x": jnp.array([0., 0.]), "y": jnp.array([.1, .1])}
initial_state = blackjax.mcmc.rmh.init(initial_position, logprob)
# now the gibbs kernel function
rmh_step_fn = blackjax.mcmc.rmh.kernel()
rmh_state_ctor = blackjax.mcmc.rmh.RMHState
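# rmh_step_fn is the generic RMH transition function returned by kernel(),
# and rmh_state_ctor builds the (position, log_probability) state it acts on;
# both are reused below to run RMH on one block at a time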
def gibbs_kernel(rng_key, state, sigma_x, sigma_y):
"""
Gibbs kernel on x and y each by RMH with covariances sigma_x and sigma_y.
"""
# setup
key_x, key_y = jax.random.split(rng_key, num=2)
position = {"x": state.position["x"], "y": state.position["y"]}
log_probability = state.log_probability
# x step
# need to create a partial state for updating just x,
# otherwise rmh_step_fn infers wrong dimension for proposal
state_x = rmh_state_ctor(position["x"], log_probability)
def logprob_x(x): return logprob({"x": x, "y": position["y"]})
state_x, _ = rmh_step_fn(
rng_key=key_x,
state=state_x,
logprob_fn=logprob_x,
sigma=sigma_x
)
position["x"] = state_x.position
# need a common log_probability for the draws,
# otherwise could have written gibbs_kernel
# with separate arguments state_x and state_y
log_probability = state_x.log_probability
# y step
state_y = rmh_state_ctor(position["y"], log_probability)
def logprob_y(y): return logprob({"x": position["x"], "y": y})
state_y, _ = rmh_step_fn(
rng_key=key_y,
state=state_y,
logprob_fn=logprob_y,
sigma=sigma_y
)
position["y"] = state_y.position
log_probability = state_y.log_probability
return rmh_state_ctor(position, log_probability)
# sample from the posterior
sigma_x = .2 * jnp.eye(2)
sigma_y = .3 * jnp.eye(2)
def inference_loop(rng_key, initial_state, num_samples, sigma_x, sigma_y):
@jax.jit
def one_step(state, rng_key):
state = gibbs_kernel(rng_key, state, sigma_x, sigma_y)
return state, state
keys = jax.random.split(rng_key, num_samples)
_, states = jax.lax.scan(one_step, initial_state, keys)
return states
rng_key = jax.random.PRNGKey(0)
states = inference_loop(
rng_key = rng_key,
initial_state = initial_state,
num_samples = 10000,
sigma_x = sigma_x,
sigma_y = sigma_y
)
# plot samples
data = pd.DataFrame({
"x1": states.position["x"][:,0],
"x2": states.position["x"][:,1],
"y1": states.position["y"][:,0],
"y2": states.position["y"][:,1]
})
sns.pairplot(data, kind="hist")
-
Great. Are you OK with me trying to turn this into a PR for the corresponding documentation item #328?
-
Delighted to get the 👍 on this! Before I jump in, I was hoping you would weigh in on a design consideration. The basic Gibbs framework above is: (1) split the current position into blocks, (2) build a partial RMH state for the block being updated, (3) run an RMH step on that block conditional on the others, and repeat for each block before reassembling the full state.
My question has to do with how to address step 2, i.e., where the partial state's log-posterior comes from. I can think of two ways to do this: (1) recompute the conditional log-posterior when constructing each partial state, or (2) reuse the log-probability just computed by the previous block's step, as in the code above.
The advantage of the first option is that it is simpler to implement and more standardized, i.e., it doesn't require an understanding of the sampler's implementation to use. The advantage of the second option is that it avoids computing log-posteriors twice. However, this comes at the expense of the user understanding the algorithm internals well enough to put the just-computed log-posterior in the right place. If you don't, the MCMC could converge to the wrong value and you'd have little way of diagnosing this in your specific application in the wild. For these reasons I'm leaning towards documenting option 1, with option 2 left for users to figure out for themselves. What do you think?
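If option 1 is read as "recompute the conditional log-posterior when constructing each partial state", a sketch against the same API as the code above might look like the following; gibbs_kernel_option1 is a hypothetical name, and rmh_step_fn and logprob are reused from that code:
def gibbs_kernel_option1(rng_key, state, sigma_x, sigma_y):
    """
    Option 1: each sub-step rebuilds its state with blackjax.mcmc.rmh.init,
    recomputing the conditional log-posterior instead of carrying it over.
    """
    key_x, key_y = jax.random.split(rng_key, num=2)
    position = {"x": state.position["x"], "y": state.position["y"]}
    # x step: init evaluates logprob_x at the current x, so no bookkeeping
    # of the previous block's log-probability is needed
    def logprob_x(x): return logprob({"x": x, "y": position["y"]})
    state_x = blackjax.mcmc.rmh.init(position["x"], logprob_x)
    state_x, _ = rmh_step_fn(rng_key=key_x, state=state_x, logprob_fn=logprob_x, sigma=sigma_x)
    position["x"] = state_x.position
    # y step: same pattern, conditioning on the newly sampled x
    def logprob_y(y): return logprob({"x": position["x"], "y": y})
    state_y = blackjax.mcmc.rmh.init(position["y"], logprob_y)
    state_y, _ = rmh_step_fn(rng_key=key_y, state=state_y, logprob_fn=logprob_y, sigma=sigma_y)
    position["y"] = state_y.position
    # the joint log-probability at the new (x, y) is already held by state_y
    return blackjax.mcmc.rmh.RMHState(position, state_y.log_probability)
Compared with the gibbs_kernel above this costs one extra log-posterior evaluation per block, but there is no log_probability to thread between sub-steps, which is the "simpler and more standardized" trade-off described above.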