-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Removal of Algorithm classes. (#657)
* more * removing export * removal of classes, tests passing * linter * fix on test * linter * removing parametrization on test * code review updates * exporting as_top_level_api in dynamic_hmc * linter * code review update: replace imports
- Loading branch information
Showing
30 changed files
with
893 additions
and
899 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,163 @@ | ||
import dataclasses | ||
from typing import Callable | ||
|
||
from blackjax._version import __version__ | ||
|
||
from .adaptation.chees_adaptation import chees_adaptation | ||
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size | ||
from .adaptation.meads_adaptation import meads_adaptation | ||
from .adaptation.pathfinder_adaptation import pathfinder_adaptation | ||
from .adaptation.window_adaptation import window_adaptation | ||
from .base import SamplingAlgorithm, VIAlgorithm | ||
from .diagnostics import effective_sample_size as ess | ||
from .diagnostics import potential_scale_reduction as rhat | ||
from .mcmc.barker import barker_proposal | ||
from .mcmc.dynamic_hmc import dynamic_hmc | ||
from .mcmc.elliptical_slice import elliptical_slice | ||
from .mcmc.ghmc import ghmc | ||
from .mcmc.hmc import hmc | ||
from .mcmc.mala import mala | ||
from .mcmc.marginal_latent_gaussian import mgrad_gaussian | ||
from .mcmc.mclmc import mclmc | ||
from .mcmc.nuts import nuts | ||
from .mcmc.periodic_orbital import orbital_hmc | ||
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh | ||
from .mcmc.rmhmc import rmhmc | ||
from .mcmc import barker | ||
from .mcmc import dynamic_hmc as _dynamic_hmc | ||
from .mcmc import elliptical_slice as _elliptical_slice | ||
from .mcmc import ghmc as _ghmc | ||
from .mcmc import hmc as _hmc | ||
from .mcmc import mala as _mala | ||
from .mcmc import marginal_latent_gaussian | ||
from .mcmc import mclmc as _mclmc | ||
from .mcmc import nuts as _nuts | ||
from .mcmc import periodic_orbital, random_walk | ||
from .mcmc import rmhmc as _rmhmc | ||
from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk | ||
from .mcmc.random_walk import ( | ||
irmh_as_top_level_api, | ||
normal_random_walk, | ||
rmh_as_top_level_api, | ||
) | ||
from .optimizers import dual_averaging, lbfgs | ||
from .sgmcmc.csgld import csgld | ||
from .sgmcmc.sghmc import sghmc | ||
from .sgmcmc.sgld import sgld | ||
from .sgmcmc.sgnht import sgnht | ||
from .smc.adaptive_tempered import adaptive_tempered_smc | ||
from .smc.inner_kernel_tuning import inner_kernel_tuning | ||
from .smc.tempered import tempered_smc | ||
from .vi.meanfield_vi import meanfield_vi | ||
from .vi.pathfinder import pathfinder | ||
from .vi.schrodinger_follmer import schrodinger_follmer | ||
from .vi.svgd import svgd | ||
from .sgmcmc import csgld as _csgld | ||
from .sgmcmc import sghmc as _sghmc | ||
from .sgmcmc import sgld as _sgld | ||
from .sgmcmc import sgnht as _sgnht | ||
from .smc import adaptive_tempered | ||
from .smc import inner_kernel_tuning as _inner_kernel_tuning | ||
from .smc import tempered | ||
from .vi import meanfield_vi as _meanfield_vi | ||
from .vi import pathfinder as _pathfinder | ||
from .vi import schrodinger_follmer as _schrodinger_follmer | ||
from .vi import svgd as _svgd | ||
from .vi.pathfinder import PathFinderAlgorithm | ||
|
||
""" | ||
The above three classes exist as a backwards compatible way of exposing both the high level, differentiable | ||
factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower | ||
level to be mostly functional programming in nature and reducing boilerplate code. | ||
""" | ||
|
||
|
||
@dataclasses.dataclass | ||
class GenerateSamplingAPI: | ||
differentiable: Callable | ||
init: Callable | ||
build_kernel: Callable | ||
|
||
def __call__(self, *args, **kwargs) -> SamplingAlgorithm: | ||
return self.differentiable(*args, **kwargs) | ||
|
||
def register_factory(self, name, callable): | ||
setattr(self, name, callable) | ||
|
||
|
||
@dataclasses.dataclass | ||
class GenerateVariationalAPI: | ||
differentiable: Callable | ||
init: Callable | ||
step: Callable | ||
sample: Callable | ||
|
||
def __call__(self, *args, **kwargs) -> VIAlgorithm: | ||
return self.differentiable(*args, **kwargs) | ||
|
||
|
||
@dataclasses.dataclass | ||
class GeneratePathfinderAPI: | ||
differentiable: Callable | ||
approximate: Callable | ||
sample: Callable | ||
|
||
def __call__(self, *args, **kwargs) -> PathFinderAlgorithm: | ||
return self.differentiable(*args, **kwargs) | ||
|
||
|
||
def generate_top_level_api_from(module): | ||
return GenerateSamplingAPI( | ||
module.as_top_level_api, module.init, module.build_kernel | ||
) | ||
|
||
|
||
# MCMC | ||
hmc = generate_top_level_api_from(_hmc) | ||
nuts = generate_top_level_api_from(_nuts) | ||
rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh) | ||
irmh = GenerateSamplingAPI( | ||
irmh_as_top_level_api, random_walk.init, random_walk.build_irmh | ||
) | ||
dynamic_hmc = generate_top_level_api_from(_dynamic_hmc) | ||
rmhmc = generate_top_level_api_from(_rmhmc) | ||
mala = generate_top_level_api_from(_mala) | ||
mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian) | ||
orbital_hmc = generate_top_level_api_from(periodic_orbital) | ||
|
||
additive_step_random_walk = GenerateSamplingAPI( | ||
_additive_step_random_walk, random_walk.init, random_walk.build_additive_step | ||
) | ||
|
||
additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) | ||
|
||
mclmc = generate_top_level_api_from(_mclmc) | ||
elliptical_slice = generate_top_level_api_from(_elliptical_slice) | ||
ghmc = generate_top_level_api_from(_ghmc) | ||
barker_proposal = generate_top_level_api_from(barker) | ||
|
||
hmc_family = [hmc, nuts] | ||
|
||
# SMC | ||
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered) | ||
tempered_smc = generate_top_level_api_from(tempered) | ||
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) | ||
|
||
smc_family = [tempered_smc, adaptive_tempered_smc] | ||
"Step_fn returning state has a .particles attribute" | ||
|
||
# stochastic gradient mcmc | ||
sgld = generate_top_level_api_from(_sgld) | ||
sghmc = generate_top_level_api_from(_sghmc) | ||
sgnht = generate_top_level_api_from(_sgnht) | ||
csgld = generate_top_level_api_from(_csgld) | ||
svgd = generate_top_level_api_from(_svgd) | ||
|
||
# variational inference | ||
meanfield_vi = GenerateVariationalAPI( | ||
_meanfield_vi.as_top_level_api, | ||
_meanfield_vi.init, | ||
_meanfield_vi.step, | ||
_meanfield_vi.sample, | ||
) | ||
schrodinger_follmer = GenerateVariationalAPI( | ||
_schrodinger_follmer.as_top_level_api, | ||
_schrodinger_follmer.init, | ||
_schrodinger_follmer.step, | ||
_schrodinger_follmer.sample, | ||
) | ||
|
||
pathfinder = GeneratePathfinderAPI( | ||
_pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample | ||
) | ||
|
||
|
||
__all__ = [ | ||
"__version__", | ||
"dual_averaging", # optimizers | ||
"lbfgs", | ||
"hmc", # mcmc | ||
"dynamic_hmc", | ||
"rmhmc", | ||
"mala", | ||
"mgrad_gaussian", | ||
"nuts", | ||
"orbital_hmc", | ||
"additive_step_random_walk", | ||
"rmh", | ||
"irmh", | ||
"mclmc", | ||
"elliptical_slice", | ||
"ghmc", | ||
"barker_proposal", | ||
"sgld", # stochastic gradient mcmc | ||
"sghmc", | ||
"sgnht", | ||
"csgld", | ||
"window_adaptation", # mcmc adaptation | ||
"meads_adaptation", | ||
"chees_adaptation", | ||
"pathfinder_adaptation", | ||
"mclmc_find_L_and_step_size", # mclmc adaptation | ||
"adaptive_tempered_smc", # smc | ||
"tempered_smc", | ||
"inner_kernel_tuning", | ||
"meanfield_vi", # variational inference | ||
"pathfinder", | ||
"schrodinger_follmer", | ||
"svgd", | ||
"ess", # diagnostics | ||
"rhat", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.