Skip to content

Commit

Permalink
fixes on comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ciguaran committed Oct 4, 2024
1 parent 34d5ba3 commit 4d6089e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 30 deletions.
4 changes: 2 additions & 2 deletions blackjax/smc/from_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def build_kernel(
update_strategy: Callable = update_and_take_last,
):
"""SMC step from MCMC kernels.
Builds MCMC kernels from the input parameters, which may change across iterations.
Moreover, it defines the way such kernels are used to update the particles. This layer
Builds MCMC kernels from the input parameters, which may change across iterations.
Moreover, it defines the way such kernels are used to update the particles. This layer
adapts an API defined in terms of kernels (mcmc_step_fn and mcmc_init_fn) into an API
that depends on an update function over the set of particles.
Returns
Expand Down
53 changes: 25 additions & 28 deletions blackjax/smc/partial_posteriors_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class PartialPosteriorsSMCState(NamedTuple):
weights:
Weights of the particles, so that they represent a probability distribution
selector:
{Datapoints used to calculate the posterior the particles represent
Datapoints used to calculate the posterior the particles represent, a 1D boolean
array to indicate which datapoints to include in the computation of the observed likelihood.
"""

particles: ArrayTree
Expand All @@ -27,8 +27,7 @@ class PartialPosteriorsSMCState(NamedTuple):


def init(particles: ArrayLikeTree, num_datapoints: int) -> PartialPosteriorsSMCState:
"""
num_datapoints are the number of observations that could potentially be
"""num_datapoints are the number of observations that could potentially be
used in a partial posterior. Since the initial selector is all 0s, it
means that no likelihood term will be added (only prior).
"""
Expand All @@ -49,28 +48,27 @@ def build_kernel(
"""Build the Partial Posteriors (data tempering) SMC kernel.
The distribution's trajectory includes increasingly adding more
datapoints to the likelihood. See Section 2.2 of https://arxiv.org/pdf/2007.11936
Parameters
----------
mcmc_step_fn
A function that computes the log density of the prior distribution
mcmc_init_fn
A function that returns the probability at a given
position.
resampling_fn
A random function that resamples generated particles based of weights
num_mcmc_steps
Number of iterations in the MCMC chain.
mcmc_parameters
A dictionary of parameters to be used by the inner MCMC kernels
partial_logposterior_factory:
A callable that given an array of 0 and 1, returns a function logposterior(x).
The array represents which values to include in the logposterior calculation. The logposterior
must be jax compilable.
Returns
-------
A callable that takes a rng_key and PartialPosteriorsSMCState and selectors for
the current and previous posteriors, and takes a data-tempered SMC state.
Parameters
----------
mcmc_step_fn
A function that computes the log density of the prior distribution
mcmc_init_fn
A function that returns the probability at a given position.
resampling_fn
A random function that resamples generated particles based of weights
num_mcmc_steps
Number of iterations in the MCMC chain.
mcmc_parameters
A dictionary of parameters to be used by the inner MCMC kernels
partial_logposterior_factory:
A callable that given an array of 0 and 1, returns a function logposterior(x).
The array represents which values to include in the logposterior calculation. The logposterior
must be jax compilable.
Returns
-------
A callable that takes a rng_key and PartialPosteriorsSMCState and selectors for
the current and previous posteriors, and takes a data-tempered SMC state.
"""
delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy)

Expand Down Expand Up @@ -102,8 +100,7 @@ def as_top_level_api(
partial_logposterior_factory: Callable,
update_strategy=update_and_take_last,
) -> SamplingAlgorithm:
"""
A factory that wraps the kernel into a SamplingAlgorithm object.
"""A factory that wraps the kernel into a SamplingAlgorithm object.
See build_kernel for full documentation on the parameters.
"""

Expand Down

0 comments on commit 4d6089e

Please sign in to comment.