Skip to content

Commit

Permalink
Update barker.py
Browse files Browse the repository at this point in the history
Make acceptance function metric agnostic
  • Loading branch information
AdrienCorenflos authored Oct 2, 2024
1 parent 9122476 commit ad59aba
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.metrics import Metric
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.types import ArrayLikeTree, ArrayTree, Numeric, PRNGKey

__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"]

Expand Down Expand Up @@ -83,28 +83,34 @@ def build_kernel():

def _compute_acceptance_probability(
state: BarkerState, proposal: BarkerState, metric: Metric
) -> float:
) -> Numeric:
"""Compute the acceptance probability of the Barker's proposal kernel."""

x = state.position
y = proposal.position
log_x = state.logdensity_grad
log_y = proposal.logdensity_grad

z1 = metric.scale(y, y, True, True)
z2 = metric.scale(x, x, True, True)
c_x = metric.scale(x, log_x, False, True)
c_y = metric.scale(y, log_y, False, True)
y_minus_x = jax.tree_util.tree_map(lambda a, b: a - b, y, x)
x_minus_y = jax.tree_util.tree_map(lambda a: -a, y_minus_x)
z_tilde_x_to_y = metric.scale(x, y_minus_x, True, True)
z_tilde_y_to_x = metric.scale(y, x_minus_y, True, True)

z1_flat, _ = ravel_pytree(z1)
z2_flat, _ = ravel_pytree(z2)
c_x_flat, _ = ravel_pytree(c_x)
c_y_flat, _ = ravel_pytree(c_y)
c_x_to_y = metric.scale(x, log_x, False, True)
c_y_to_x = metric.scale(y, log_y, False, True)

z = z1_flat - z2_flat
z_tilde_x_to_y_flat, _ = ravel_pytree(z_tilde_x_to_y)
z_tilde_y_to_x_flat, _ = ravel_pytree(z_tilde_y_to_x)

num = _log1pexp(-z * c_x_flat)
denom = _log1pexp(z * c_y_flat)
c_x_to_y_flat, _ = ravel_pytree(c_x_to_y)
c_y_to_x_flat, _ = ravel_pytree(c_y_to_x)

num = metric.kinetic_energy(x_minus_y, y) - _log1pexp(
-z_tilde_y_to_x_flat * c_y_to_x_flat
)
denom = metric.kinetic_energy(y_minus_x, x) - _log1pexp(
-z_tilde_x_to_y_flat * c_x_to_y_flat
)

ratio_proposal = jnp.sum(num - denom)

Expand All @@ -121,7 +127,7 @@ def kernel(
if inverse_mass_matrix is None:
p, _ = ravel_pytree(state.position)
(m,) = p.shape
inverse_mass_matrix = jnp.identity(m)
inverse_mass_matrix = jnp.ones((m,))
metric = metrics.default_metric(inverse_mass_matrix)
grad_fn = jax.value_and_grad(logdensity_fn)
key_sample, key_rmh = jax.random.split(rng_key)
Expand Down Expand Up @@ -259,7 +265,9 @@ def _barker_sample_nd(key, mean, a, scale, metric):
b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape)

# return mean + z if b == 1 else mean - z
return mean + metric.scale(mean, b * z - (1 - b) * z, False, False)
return jax.tree_util.tree_map(
lambda a, b: a + b, mean, metric.scale(mean, b * z - (1 - b) * z, False, False)
)


def _barker_sample(key, mean, a, scale, metric):
Expand All @@ -283,7 +291,7 @@ def _barker_sample(key, mean, a, scale, metric):

flat_mean, unravel_fn = ravel_pytree(mean)
flat_a, _ = ravel_pytree(a)
flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale, metric)
flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale)
return unravel_fn(flat_sample)


Expand Down

0 comments on commit ad59aba

Please sign in to comment.