Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jcopo committed Oct 1, 2024
1 parent 51a5dc6 commit 407d7c0
Showing 1 changed file with 125 additions and 11 deletions.
136 changes: 125 additions & 11 deletions blackjax/mcmc/diffusive_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import blackjax.util

from typing import Callable, NamedTuple, Tuple

from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax import SamplingAlgorithm

import jax
import jax.numpy as jnp

import blackjax.util
from blackjax import SamplingAlgorithm
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey


class GibbsState(NamedTuple):
"""State of the Diffusive Gibbs algorithm."""

position: ArrayTree
logdensity: float
logdensity_grad: ArrayTree
Expand All @@ -31,13 +32,45 @@ class GibbsState(NamedTuple):
count: int


def noiser(rng_key: PRNGKey, state: GibbsState):
def noiser(rng_key: PRNGKey, state: GibbsState) -> ArrayTree:
"""Generate a noised position based on the current state.
Parameters
----------
rng_key : PRNGKey
The random number generator key.
state : GibbsState
The current state of the Gibbs sampler.
Returns
-------
ArrayTree
The noised position.
"""
position, _, _, noise_contraction, noise_sigma, _ = state
noise = blackjax.util.generate_gaussian_noise(rng_key, position, 0, noise_sigma)
return jax.tree.map(lambda x, n: noise_contraction * x + n, position, noise)


def noiser_logpdf(state: GibbsState, sample_noised: ArrayTree, sample_clean: ArrayTree):
def noiser_logpdf(
state: GibbsState, sample_noised: ArrayTree, sample_clean: ArrayTree
) -> float:
"""Compute the log probability density of the noised sample given the clean sample.
Parameters
----------
state : GibbsState
The current state of the Gibbs sampler.
sample_noised : ArrayTree
The noised sample.
sample_clean : ArrayTree
The clean sample.
Returns
-------
float
The log probability density.
"""
mean = jax.tree.map(jnp.multiply, sample_clean, state.noise_contraction)
return jax.scipy.stats.norm.logpdf(sample_noised, mean, state.noise_sigma**2).sum()

Expand All @@ -47,7 +80,25 @@ def init_denoising(
noised_position: ArrayTree,
state: GibbsState,
logdensity_fn: Callable,
):
) -> ArrayTree:
"""Initialize the denoising process.
Parameters
----------
rng_key : PRNGKey
The random number generator key.
noised_position : ArrayTree
The noised position.
state : GibbsState
The current state of the Gibbs sampler.
logdensity_fn : Callable
The log density function.
Returns
-------
ArrayTree
Position at which to start the denoising process.
"""
position, _, _, noise_contraction, noise_sigma, count = state
noise = blackjax.util.generate_gaussian_noise(
rng_key, noised_position, 0, noise_sigma / noise_contraction
Expand Down Expand Up @@ -92,8 +143,26 @@ def gaussian_term(a, b, scale):


def denoise(
rng_key, position: ArrayLikeTree, denoiser: SamplingAlgorithm, n_steps: int
):
rng_key: PRNGKey, position: ArrayLikeTree, denoiser: SamplingAlgorithm, n_steps: int
) -> ArrayTree:
"""Perform denoising steps.
Parameters
----------
rng_key : PRNGKey
The random number generator key.
position : ArrayLikeTree
The initial position.
denoiser : SamplingAlgorithm
The denoising algorithm.
n_steps : int
The number of denoising steps.
Returns
-------
ArrayTree
The denoised position.
"""
init_state = denoiser.init(position, rng_key)

def body_fn(state, rng_key):
Expand All @@ -107,13 +176,41 @@ def body_fn(state, rng_key):


def build_kernel():
"""Build the Diffusive Gibbs kernel.
Returns
-------
Callable
The Diffusive Gibbs kernel function.
"""

def kernel(
rng_key: PRNGKey,
state: GibbsState,
logdensity_fn: Callable,
n_steps: int,
schedule: Callable[[int], Tuple[float, float]],
):
) -> GibbsState:
"""Generate a new sample with the Diffusive Gibbs kernel.
Parameters
----------
rng_key : PRNGKey
The random number generator key.
state : GibbsState
The current state of the Gibbs sampler.
logdensity_fn : Callable
The log density function.
n_steps : int
The number of denoising steps.
schedule : Callable[[int], Tuple[float, float]]
A function that returns the noise contraction and noise sigma for each step.
Returns
-------
GibbsState
The new state of the Gibbs sampler.
"""
_, _, _, noise_contraction, noise_sigma, count = state
grad_fn = jax.value_and_grad(logdensity_fn)
logdensity, logdensity_grad = grad_fn(state.position)
Expand Down Expand Up @@ -149,6 +246,23 @@ def as_top_level_api(
n_steps: int = 10,
schedule: Callable[[int], Tuple[float, float]] = lambda _: (0.9, 0.1),
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Diffusive Gibbs kernel.
Parameters
----------
logdensity_fn : Callable
The log density function we wish to draw samples from.
n_steps : int, optional
The number of denoising steps, by default 10.
schedule : Callable[[int], Tuple[float, float]], optional
A function that returns the noise contraction and noise sigma for each step,
by default lambda _: (0.9, 0.1).
Returns
-------
SamplingAlgorithm
A ``SamplingAlgorithm`` instance for the Diffusive Gibbs kernel.
"""
kernel = build_kernel()

def init(position: ArrayLikeTree, rng_key=None) -> GibbsState:
Expand Down

0 comments on commit 407d7c0

Please sign in to comment.