Skip to content

Commit

Permalink
Removal of Algorithm classes. (#657)
Browse files Browse the repository at this point in the history
* more

* removing export

* removal of classes, tests passing

* linter

* fix on test

* linter

* removing parametrization on test

* code review updates

* exporting as_top_level_api in dynamic_hmc

* linter

* code review update: replace imports
  • Loading branch information
ciguaran authored Apr 22, 2024
1 parent a5f7482 commit 1bc6f93
Show file tree
Hide file tree
Showing 30 changed files with 893 additions and 899 deletions.
188 changes: 140 additions & 48 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,163 @@
import dataclasses
from typing import Callable

from blackjax._version import __version__

from .adaptation.chees_adaptation import chees_adaptation
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
from .adaptation.meads_adaptation import meads_adaptation
from .adaptation.pathfinder_adaptation import pathfinder_adaptation
from .adaptation.window_adaptation import window_adaptation
from .base import SamplingAlgorithm, VIAlgorithm
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .mcmc.barker import barker_proposal
from .mcmc.dynamic_hmc import dynamic_hmc
from .mcmc.elliptical_slice import elliptical_slice
from .mcmc.ghmc import ghmc
from .mcmc.hmc import hmc
from .mcmc.mala import mala
from .mcmc.marginal_latent_gaussian import mgrad_gaussian
from .mcmc.mclmc import mclmc
from .mcmc.nuts import nuts
from .mcmc.periodic_orbital import orbital_hmc
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
from .mcmc.rmhmc import rmhmc
from .mcmc import barker
from .mcmc import dynamic_hmc as _dynamic_hmc
from .mcmc import elliptical_slice as _elliptical_slice
from .mcmc import ghmc as _ghmc
from .mcmc import hmc as _hmc
from .mcmc import mala as _mala
from .mcmc import marginal_latent_gaussian
from .mcmc import mclmc as _mclmc
from .mcmc import nuts as _nuts
from .mcmc import periodic_orbital, random_walk
from .mcmc import rmhmc as _rmhmc
from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk
from .mcmc.random_walk import (
irmh_as_top_level_api,
normal_random_walk,
rmh_as_top_level_api,
)
from .optimizers import dual_averaging, lbfgs
from .sgmcmc.csgld import csgld
from .sgmcmc.sghmc import sghmc
from .sgmcmc.sgld import sgld
from .sgmcmc.sgnht import sgnht
from .smc.adaptive_tempered import adaptive_tempered_smc
from .smc.inner_kernel_tuning import inner_kernel_tuning
from .smc.tempered import tempered_smc
from .vi.meanfield_vi import meanfield_vi
from .vi.pathfinder import pathfinder
from .vi.schrodinger_follmer import schrodinger_follmer
from .vi.svgd import svgd
from .sgmcmc import csgld as _csgld
from .sgmcmc import sghmc as _sghmc
from .sgmcmc import sgld as _sgld
from .sgmcmc import sgnht as _sgnht
from .smc import adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import tempered
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
from .vi import schrodinger_follmer as _schrodinger_follmer
from .vi import svgd as _svgd
from .vi.pathfinder import PathFinderAlgorithm

"""
The above three classes exist as a backwards compatible way of exposing both the high level, differentiable
factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower
level to be mostly functional programming in nature and reducing boilerplate code.
"""


@dataclasses.dataclass
class GenerateSamplingAPI:
differentiable: Callable
init: Callable
build_kernel: Callable

def __call__(self, *args, **kwargs) -> SamplingAlgorithm:
return self.differentiable(*args, **kwargs)

def register_factory(self, name, callable):
setattr(self, name, callable)


@dataclasses.dataclass
class GenerateVariationalAPI:
differentiable: Callable
init: Callable
step: Callable
sample: Callable

def __call__(self, *args, **kwargs) -> VIAlgorithm:
return self.differentiable(*args, **kwargs)


@dataclasses.dataclass
class GeneratePathfinderAPI:
differentiable: Callable
approximate: Callable
sample: Callable

def __call__(self, *args, **kwargs) -> PathFinderAlgorithm:
return self.differentiable(*args, **kwargs)


def generate_top_level_api_from(module):
return GenerateSamplingAPI(
module.as_top_level_api, module.init, module.build_kernel
)


# MCMC
hmc = generate_top_level_api_from(_hmc)
nuts = generate_top_level_api_from(_nuts)
rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh)
irmh = GenerateSamplingAPI(
irmh_as_top_level_api, random_walk.init, random_walk.build_irmh
)
dynamic_hmc = generate_top_level_api_from(_dynamic_hmc)
rmhmc = generate_top_level_api_from(_rmhmc)
mala = generate_top_level_api_from(_mala)
mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian)
orbital_hmc = generate_top_level_api_from(periodic_orbital)

additive_step_random_walk = GenerateSamplingAPI(
_additive_step_random_walk, random_walk.init, random_walk.build_additive_step
)

additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)

mclmc = generate_top_level_api_from(_mclmc)
elliptical_slice = generate_top_level_api_from(_elliptical_slice)
ghmc = generate_top_level_api_from(_ghmc)
barker_proposal = generate_top_level_api_from(barker)

hmc_family = [hmc, nuts]

# SMC
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)

smc_family = [tempered_smc, adaptive_tempered_smc]
"Step_fn returning state has a .particles attribute"

# stochastic gradient mcmc
sgld = generate_top_level_api_from(_sgld)
sghmc = generate_top_level_api_from(_sghmc)
sgnht = generate_top_level_api_from(_sgnht)
csgld = generate_top_level_api_from(_csgld)
svgd = generate_top_level_api_from(_svgd)

# variational inference
meanfield_vi = GenerateVariationalAPI(
_meanfield_vi.as_top_level_api,
_meanfield_vi.init,
_meanfield_vi.step,
_meanfield_vi.sample,
)
schrodinger_follmer = GenerateVariationalAPI(
_schrodinger_follmer.as_top_level_api,
_schrodinger_follmer.init,
_schrodinger_follmer.step,
_schrodinger_follmer.sample,
)

pathfinder = GeneratePathfinderAPI(
_pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample
)


__all__ = [
"__version__",
"dual_averaging", # optimizers
"lbfgs",
"hmc", # mcmc
"dynamic_hmc",
"rmhmc",
"mala",
"mgrad_gaussian",
"nuts",
"orbital_hmc",
"additive_step_random_walk",
"rmh",
"irmh",
"mclmc",
"elliptical_slice",
"ghmc",
"barker_proposal",
"sgld", # stochastic gradient mcmc
"sghmc",
"sgnht",
"csgld",
"window_adaptation", # mcmc adaptation
"meads_adaptation",
"chees_adaptation",
"pathfinder_adaptation",
"mclmc_find_L_and_step_size", # mclmc adaptation
"adaptive_tempered_smc", # smc
"tempered_smc",
"inner_kernel_tuning",
"meanfield_vi", # variational inference
"pathfinder",
"schrodinger_follmer",
"svgd",
"ess", # diagnostics
"rhat",
]
5 changes: 2 additions & 3 deletions blackjax/adaptation/pathfinder_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Pathinder warmup for the HMC family of sampling algorithms."""
from typing import Callable, NamedTuple, Union
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

import blackjax.mcmc as mcmc
import blackjax.vi as vi
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.step_size import (
Expand Down Expand Up @@ -138,7 +137,7 @@ def final(warmup_state: PathfinderAdaptationState) -> tuple[float, Array]:


def pathfinder_adaptation(
algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts],
algorithm,
logdensity_fn: Callable,
initial_step_size: float = 1.0,
target_acceptance_rate: float = 0.80,
Expand Down
7 changes: 3 additions & 4 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Stan warmup for the HMC family of sampling algorithms."""
from typing import Callable, NamedTuple, Union
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

import blackjax.mcmc as mcmc
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.mass_matrix import (
MassMatrixAdaptationState,
Expand Down Expand Up @@ -243,7 +242,7 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]:


def window_adaptation(
algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts],
algorithm,
logdensity_fn: Callable,
is_mass_matrix_diagonal: bool = True,
initial_step_size: float = 1.0,
Expand All @@ -252,7 +251,7 @@ def window_adaptation(
**extra_parameters,
) -> AdaptationAlgorithm:
"""Adapt the value of the inverse mass matrix and step size parameters of
algorithms in the HMC fmaily.
algorithms in the HMC family. See Blackjax.hmc_family
Algorithms in the HMC family on a euclidean manifold depend on the value of
at least two parameters: the step size, related to the trajectory
Expand Down
29 changes: 12 additions & 17 deletions blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "barker_proposal"]
__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"]


class BarkerState(NamedTuple):
Expand Down Expand Up @@ -128,7 +128,10 @@ def kernel(
return kernel


class barker_proposal:
def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a
Gaussian base kernel.
Expand Down Expand Up @@ -179,24 +182,16 @@ class barker_proposal:
"""

init = staticmethod(init)
build_kernel = staticmethod(build_kernel)
kernel = build_kernel()

def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
step_size: float,
) -> SamplingAlgorithm:
kernel = cls.build_kernel()
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return init(position, logdensity_fn)

def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)
def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state, logdensity_fn, step_size)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state, logdensity_fn, step_size)

return SamplingAlgorithm(init_fn, step_fn)
return SamplingAlgorithm(init_fn, step_fn)


def _barker_sample_nd(key, mean, a, scale):
Expand Down
Loading

0 comments on commit 1bc6f93

Please sign in to comment.