Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Nested rhat MCMC diagnostic #752

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .adaptation.window_adaptation import window_adaptation
from .base import SamplingAlgorithm, VIAlgorithm
from .diagnostics import effective_sample_size as ess
from .diagnostics import nested_rhat as nested_rhat
from .diagnostics import potential_scale_reduction as rhat
from .mcmc import barker
from .mcmc import dynamic_hmc as _dynamic_hmc
Expand Down Expand Up @@ -161,5 +162,6 @@ def generate_top_level_api_from(module):
"pathfinder_adaptation",
"mclmc_find_L_and_step_size", # mclmc adaptation
"ess", # diagnostics
"nested_rhat",
"rhat",
]
62 changes: 61 additions & 1 deletion blackjax/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from blackjax.types import Array, ArrayLike

__all__ = ["potential_scale_reduction", "effective_sample_size"]
__all__ = ["potential_scale_reduction", "nested_rhat", "effective_sample_size"]


def potential_scale_reduction(
Expand Down Expand Up @@ -75,6 +75,66 @@ def potential_scale_reduction(
return rhat_value.squeeze()


def nested_rhat(
input_array: ArrayLike,
superchain_axis: int = 0,
chain_axis: int = 1,
sample_axis: int = 2,
) -> Array:
"""Margossian et al. (2024)'s nested R-hat for computing multiple MCMC superchain convergence.

Parameters
----------
input_array
An array representing multiple superchains of MCMC smaples. The array must
contain a superchain dimension, chain dimension, and sample dimension.
superchain_axis
The axis indicating the multiple superchains. Default to 0.
chain_axis
The axis indicating the multiple chains. Default to 1.
sample_axis
The axis indicating a single chain of MCMC samples. Default to 2.

Returns
-------
NDArray of the resulting statistics (r-hat), with the chain and sample dimensions squeezed.

"""
assert input_array.ndim == 4, "The input array must have 4 dimensions."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should relax the ndim, as our input could have multiple dimensions of event shape (ie the random variable is non-scaler).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use keepdims=True and it should works.

num_chains = input_array.shape[chain_axis]
num_samples = input_array.shape[sample_axis]
param_axis = 3 - (chain_axis + sample_axis + superchain_axis)
num_params = input_array.shape[param_axis]
assert (
num_chains > 1 or num_samples > 1
), "num_chains or num_samples must be greater than 1 for valid nested R-hat."

chain_means = jnp.mean(input_array, axis=sample_axis)
super_means = jnp.mean(chain_means, axis=chain_axis)
total_mean = jnp.mean(super_means, axis=superchain_axis)

between_var = jnp.mean(jnp.square(super_means - total_mean), axis=superchain_axis)

if num_chains > 1:
within_chain_var = jnp.mean(
jnp.square(chain_means - super_means), axis=chain_axis
)
else:
within_chain_var = jnp.zeros(num_params)

if num_samples > 1:
within_super_var = jnp.mean(
jnp.square(input_array - chain_means), axis=(chain_axis, sample_axis)
)
else:
within_super_var = jnp.zeros(num_params)

within_var = jnp.mean(within_chain_var + within_super_var, axis=superchain_axis)

nested_rhat_value = jnp.sqrt(1 + between_var / within_var)
return nested_rhat_value.squeeze()


def effective_sample_size(
input_array: ArrayLike, chain_axis: int = 0, sample_axis: int = 1
) -> Array:
Expand Down
Loading