From 4d6089e06ee2349909c4f8ca1ceff7e243bdc5cf Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Fri, 4 Oct 2024 12:33:32 -0300 Subject: [PATCH] fixes on comments --- blackjax/smc/from_mcmc.py | 4 +- blackjax/smc/partial_posteriors_path.py | 53 ++++++++++++------------- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/blackjax/smc/from_mcmc.py b/blackjax/smc/from_mcmc.py index 41546a308..0e60b5968 100644 --- a/blackjax/smc/from_mcmc.py +++ b/blackjax/smc/from_mcmc.py @@ -15,8 +15,8 @@ def build_kernel( update_strategy: Callable = update_and_take_last, ): """SMC step from MCMC kernels. - Builds MCMC kernels from the input parameters, which may change across iterations. - Moreover, it defines the way such kernels are used to update the particles. This layer + Builds MCMC kernels from the input parameters, which may change across iterations. + Moreover, it defines the way such kernels are used to update the particles. This layer adapts an API defined in terms of kernels (mcmc_step_fn and mcmc_init_fn) into an API that depends on an update function over the set of particles. Returns diff --git a/blackjax/smc/partial_posteriors_path.py b/blackjax/smc/partial_posteriors_path.py index 753d00247..2381152f4 100644 --- a/blackjax/smc/partial_posteriors_path.py +++ b/blackjax/smc/partial_posteriors_path.py @@ -17,8 +17,8 @@ class PartialPosteriorsSMCState(NamedTuple): weights: Weights of the particles, so that they represent a probability distribution selector: - {Datapoints used to calculate the posterior the particles represent - + Datapoints used to calculate the posterior the particles represent, a 1D boolean + array to indicate which datapoints to include in the computation of the observed likelihood. """ particles: ArrayTree @@ -27,8 +27,7 @@ class PartialPosteriorsSMCState(NamedTuple): def init(particles: ArrayLikeTree, num_datapoints: int) -> PartialPosteriorsSMCState: - """ - num_datapoints are the number of observations that could potentially be + """num_datapoints are the number of observations that could potentially be used in a partial posterior. Since the initial selector is all 0s, it means that no likelihood term will be added (only prior). """ @@ -49,28 +48,27 @@ def build_kernel( """Build the Partial Posteriors (data tempering) SMC kernel. The distribution's trajectory includes increasingly adding more datapoints to the likelihood. See Section 2.2 of https://arxiv.org/pdf/2007.11936 - Parameters - ---------- - mcmc_step_fn - A function that computes the log density of the prior distribution - mcmc_init_fn - A function that returns the probability at a given - position. - resampling_fn - A random function that resamples generated particles based of weights - num_mcmc_steps - Number of iterations in the MCMC chain. - mcmc_parameters - A dictionary of parameters to be used by the inner MCMC kernels - partial_logposterior_factory: - A callable that given an array of 0 and 1, returns a function logposterior(x). - The array represents which values to include in the logposterior calculation. The logposterior - must be jax compilable. - - Returns - ------- - A callable that takes a rng_key and PartialPosteriorsSMCState and selectors for - the current and previous posteriors, and takes a data-tempered SMC state. + Parameters + ---------- + mcmc_step_fn + A function that computes the log density of the prior distribution + mcmc_init_fn + A function that returns the probability at a given position. + resampling_fn + A random function that resamples generated particles based of weights + num_mcmc_steps + Number of iterations in the MCMC chain. + mcmc_parameters + A dictionary of parameters to be used by the inner MCMC kernels + partial_logposterior_factory: + A callable that given an array of 0 and 1, returns a function logposterior(x). + The array represents which values to include in the logposterior calculation. The logposterior + must be jax compilable. + + Returns + ------- + A callable that takes a rng_key and PartialPosteriorsSMCState and selectors for + the current and previous posteriors, and takes a data-tempered SMC state. """ delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy) @@ -102,8 +100,7 @@ def as_top_level_api( partial_logposterior_factory: Callable, update_strategy=update_and_take_last, ) -> SamplingAlgorithm: - """ - A factory that wraps the kernel into a SamplingAlgorithm object. + """A factory that wraps the kernel into a SamplingAlgorithm object. See build_kernel for full documentation on the parameters. """