diff --git a/.gitignore b/.gitignore index cad2ac38d..a4fdf64a5 100644 --- a/.gitignore +++ b/.gitignore @@ -16,7 +16,17 @@ sftp-config.json *.local.* # MacOS storage files .DS_Store +.dockerignore +.ipynb_checkpoints/ +docs/_build/ +resources/benchmarks/datasets +resources/2d #CLion cash .idea/ # Build debug folder cmake-build-debug/ + +# .old folders +src/hierarchies/updaters/.old/ +test/.old/ +examples/gamma_hierarchy/.old/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 28e08f633..e79f07f43 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,7 @@ set(INCLUDE_PATHS ${CMAKE_CURRENT_LIST_DIR}/lib/math ${CMAKE_CURRENT_LIST_DIR}/lib/math/lib/boost_1.72.0 ${CMAKE_CURRENT_LIST_DIR}/lib/math/lib/eigen_3.3.9 + ${CMAKE_CURRENT_LIST_DIR}/lib/math/lib/sundials_5.5.0/include # TBB already included ${CMAKE_CURRENT_BINARY_DIR} ${protobuf_SOURCE_DIR}/src diff --git a/README.md b/README.md index d924dacd5..7cb6972a4 100644 --- a/README.md +++ b/README.md @@ -8,17 +8,15 @@ Current state of the software: - `bayesmix` performs inference for mixture models of the kind + + + -where P is either the Dirichlet process or the Pitman--Yor process - -- We currently support univariate and multivariate location-scale mixture of Gaussian densities - -- Inference is carried out using algorithms such as Algorithm 2 in [Neal (2000)](http://www.stat.columbia.edu/npbayes/papers/neal_sampling.pdf) - -- Serialization of the MCMC chains is possible using Google's [Protocol Buffers](https://developers.google.com/protocol-buffers) aka `protobuf` +For descriptions of the models supported in our library, discussion of software design, and examples, please refer to the following paper: https://arxiv.org/abs/2205.08144 # Installation @@ -29,7 +27,7 @@ where P is either the Dirichlet process or the Pitman--Yor process To install and use `bayesmix`, please `cd` to the folder to which you wish to install it, and clone this repository with the following command-line instruction: ```shell -git 
clone --recurse-submodule git@github.com:bayesmix-dev/bayesmix.git +git clone --recurse-submodules git@github.com:bayesmix-dev/bayesmix.git ``` Then, by using `cd bayesmix`, you will enter the newly downloaded folder. @@ -39,8 +37,8 @@ To build the executable for the main file `run_mcmc.cc`, please use the followin ```shell mkdir build cd build -cmake .. -DDISABLE_DOCS=on -DDISABLE_BENCHMARKS=on -DDISABLE_TESTS=on -make run +cmake .. -DDISABLE_DOCS=ON -DDISABLE_BENCHMARKS=ON -DDISABLE_TESTS=ON +make run_mcmc cd .. ``` @@ -97,3 +95,4 @@ Documentation is available at https://bayesmix.readthedocs.io. # Contributions are welcome! Please check out [CONTRIBUTING.md](CONTRIBUTING.md) for details on how to collaborate with us. +You can also head to our [issues page](https://github.com/bayesmix-dev/bayesmix/issues) to check for useful enhancements needed. diff --git a/benchmarks/lpd_grid.cc b/benchmarks/lpd_grid.cc index d297f5943..507b20d0b 100644 --- a/benchmarks/lpd_grid.cc +++ b/benchmarks/lpd_grid.cc @@ -1,7 +1,7 @@ #include -#include #include +#include #include "src/utils/distributions.h" #include "utils.h" @@ -37,7 +37,6 @@ Eigen::VectorXd lpdf_naive(const Eigen::MatrixXd &x, return out; } - Eigen::VectorXd lpdf_fully_optimized(const Eigen::MatrixXd &x, const Eigen::VectorXd &mean, const Eigen::MatrixXd &prec_chol, @@ -84,7 +83,6 @@ static void BM_gauss_lpdf_naive(benchmark::State &state) { } } - static void BM_gauss_lpdf_fully_optimized(benchmark::State &state) { int dim = state.range(0); Eigen::VectorXd mean = Eigen::VectorXd::Zero(dim); @@ -99,7 +97,6 @@ static void BM_gauss_lpdf_fully_optimized(benchmark::State &state) { } } - BENCHMARK(BM_gauss_lpdf_cov)->RangeMultiplier(2)->Range(2, 2 << 4); BENCHMARK(BM_gauss_lpdf_cov)->RangeMultiplier(2)->Range(2, 2 << 4); BENCHMARK(BM_gauss_lpdf_naive)->RangeMultiplier(2)->Range(2, 2 << 4); diff --git a/benchmarks/nnw_marg_lpdf.cc b/benchmarks/nnw_marg_lpdf.cc index 696a8a198..a99314692 100644 --- 
a/benchmarks/nnw_marg_lpdf.cc +++ b/benchmarks/nnw_marg_lpdf.cc @@ -1,7 +1,7 @@ #include -#include #include +#include #include "utils.h" diff --git a/docs/Doxyfile.in b/docs/Doxyfile.in index 30cda29e9..6b163fa99 100644 --- a/docs/Doxyfile.in +++ b/docs/Doxyfile.in @@ -1819,7 +1819,7 @@ PAPER_TYPE = a4 # If left blank no extra packages will be included. # This tag requires that the tag GENERATE_LATEX is set to YES. -EXTRA_PACKAGES = +EXTRA_PACKAGES = bm # The LATEX_HEADER tag can be used to specify a personal LaTeX header for the # generated LaTeX document. The header should contain everything until the first @@ -1893,7 +1893,7 @@ USE_PDFLATEX = YES # The default value is: NO. # This tag requires that the tag GENERATE_LATEX is set to YES. -LATEX_BATCHMODE = NO +LATEX_BATCHMODE = YES # If the LATEX_HIDE_INDICES tag is set to YES then doxygen will not include the # index chapters (such as File Index, Compound Index, etc.) in the output. diff --git a/docs/algorithms.rst b/docs/algorithms.rst index c61efc257..e15aff394 100644 --- a/docs/algorithms.rst +++ b/docs/algorithms.rst @@ -20,6 +20,9 @@ Algorithms .. doxygenclass:: Neal8Algorithm :project: bayesmix :members: +.. doxygenclass:: SplitAndMergeAlgorithm + :project: bayesmix + :members: .. doxygenclass:: ConditionalAlgorithm :project: bayesmix :members: diff --git a/docs/conf.py b/docs/conf.py index 3501007dd..9acbf232a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -57,8 +57,6 @@ def configureDoxyfile(input_dir, output_dir): html_theme = 'haiku' -html_static_path = ['_static'] - highlight_language = 'cpp' imgmath_latex = 'latex' diff --git a/docs/hierarchies.rst b/docs/hierarchies.rst index 21c6fe1d7..c6dedcab9 100644 --- a/docs/hierarchies.rst +++ b/docs/hierarchies.rst @@ -6,6 +6,20 @@ Hierarchies In our algorithms, we store a vector of hierarchies, each of which represent a parameter :math:`\theta_h`. 
The hierarchy implements all the methods needed to update :math:`\theta_h`: sampling from the prior distribution :math:`P_0`, the full-conditional distribution (given the data {:math:`y_i` such that :math:`c_i = h`} ) and so on. +In BayesMix, each choice of :math:`G_0` is implemented in a different ``PriorModel`` object and each choice of :math:`k(\cdot \mid \cdot)` in a ``Likelihood`` object, so that it is straightforward to create a new ``Hierarchy`` using one of the already implemented priors or likelihoods. +The sampling from the full conditional of :math:`\theta_h` is performed in an ``Updater`` class. +`State` classes are used to store parameters :math:`\theta_h` s of every mixture component. +Their main purpose is to handle serialization and de-serialization of the state + +.. toctree:: + :maxdepth: 1 + :caption: API: hierarchies submodules + + likelihoods + prior_models + updaters + states + ------------------------- Main operations performed @@ -13,17 +27,17 @@ Main operations performed A hierarchy must be able to perform the following operations: -1. Sample from the prior distribution: generate :math:`\theta_h \sim P_0` [``sample_prior``] -2. Sample from the 'full conditional' distribution: generate theta_h from the distribution :math:`p(\theta_h \mid \cdots ) \propto P_0(\theta_h) \prod_{i: c_i = h} k(y_i | \theta_h)` [``sample_full_conditional``] -3. Update the hyperparameters involved in :math:`P_0` [``update_hypers``] -4. Evaluate the likelihood in one point, i.e. :math:`k(x | \theta_h)` for theta_h the current value of the parameters [``like_lpdf``] -5. When :math:`k` and :math:`P_0` are conjugate, we must also be able to compute the marginal/prior predictive distribution in one point, i.e. :math:`m(x) = \int k(x | \theta) P_0(d\theta)`, and the conditional predictive distribution :math:`m(x | \textbf{y} ) = \int k(x | \theta) P_0(d\theta | \{y_i: c_i = h\})` [``prior_pred_lpdf``, ``conditional_pred_lpdf``] +a. 
Sample from the prior distribution: generate :math:`\theta_h \sim P_0` [``sample_prior``] +b. Sample from the 'full conditional' distribution: generate theta_h from the distribution :math:`p(\theta_h \mid \cdots ) \propto P_0(\theta_h) \prod_{i: c_i = h} k(y_i | \theta_h)` [``sample_full_conditional``] +c. Update the hyperparameters involved in :math:`P_0` [``update_hypers``] +d. Evaluate the likelihood in one point, i.e. :math:`k(x | \theta_h)` for theta_h the current value of the parameters [``like_lpdf``] +e. When :math:`k` and :math:`P_0` are conjugate, we must also be able to compute the marginal/prior predictive distribution in one point, i.e. :math:`m(x) = \int k(x | \theta) P_0(d\theta)`, and the conditional predictive distribution :math:`m(x | \textbf{y} ) = \int k(x | \theta) P_0(d\theta | \{y_i: c_i = h\})` [``prior_pred_lpdf``, ``conditional_pred_lpdf``] Moreover, the following utilities are needed: -6. write the current state :math:`\theta_h` into a appropriately defined Protobuf message [``write_state_to_proto``] -7. restore theta_h from a given Protobuf message [``set_state_from_proto``] -8. write the values of the hyperparameters in :math:`P_0` to a Protobuf message [``write_hypers_to_proto``] +f. write the current state :math:`\theta_h` into a appropriately defined Protobuf message [``write_state_to_proto``] +g. restore theta_h from a given Protobuf message [``set_state_from_proto``] +h. write the values of the hyperparameters in :math:`P_0` to a Protobuf message [``write_hypers_to_proto``] In each hierarchy, we also keep track of which data points are allocated to the hierarchy. @@ -44,19 +58,18 @@ A template class ``BaseHierarchy`` inherits from ``AbstractHierarchy`` and imple Instead, child classes must implement: -1. ``like_lpdf``: evaluates :math:`k(x | \theta_h)` -2. ``marg_lpdf``: evaluates m(x) given some parameters :math:`\theta_h` (could be both the hyperparameters in :math:`P_0` or the paramters given by the full conditionals) -3. 
``draw``: samples from :math:`P_0` given the parameters -4. ``clear_summary_statistics``: clears all the summary statistics -5. ``update_hypers``: performs the update of parameters in :math:`P_0` given all the :math:`\theta_h` (passed as a vector of protobuf Messages) -6. ``initialize_state``: initializes the current :math:`\theta_h` given the hyperparameters in :math:`P_0` -7. ``initialize_hypers``: initializes the hyperparameters in :math:`P_0` given their hyperprior -8. ``update_summary_statistics``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy -9. ``get_posterior_parameters``: returns the paramters of the full conditional distribution **possible only when** :math:`P_0` **and** :math:`k` **are conjugate** -10. ``set_state_from_proto`` -11. ``write_state_to_proto`` -12. ``write_hypers_to_proto`` - +a. ``like_lpdf``: evaluates :math:`k(x | \theta_h)` +b. ``marg_lpdf``: evaluates m(x) given some parameters :math:`\theta_h` (could be both the hyperparameters in :math:`P_0` or the parameters given by the full conditionals) +c. ``draw``: samples from :math:`P_0` given the parameters +d. ``clear_summary_statistics``: clears all the summary statistics +e. ``update_hypers``: performs the update of parameters in :math:`P_0` given all the :math:`\theta_h` (passed as a vector of protobuf Messages) +f. ``initialize_state``: initializes the current :math:`\theta_h` given the hyperparameters in :math:`P_0` +g. ``initialize_hypers``: initializes the hyperparameters in :math:`P_0` given their hyperprior +h. ``update_summary_statistics``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy +i. ``get_posterior_parameters``: returns the parameters of the full conditional distribution **possible only when** :math:`P_0` **and** :math:`k` **are conjugate** +j. ``set_state_from_proto`` +k. ``write_state_to_proto`` +l. 
``write_hypers_to_proto`` Note that not all of these members are declared virtual in ``AbstractHierarchy`` or ``BaseHierarchy``: this is because virtual members are only the ones that must be called from outside the ``Hierarchy``, the other ones are handled via CRTP. Not having them virtual saves a lot of lookups in the vtables. The ``BaseHierarchy`` class takes 4 template parameters: @@ -66,12 +79,12 @@ The ``BaseHierarchy`` class takes 4 template parameters: 3. ``Hyperparams`` is usually a struct representing the parameters in :math:`P_0` 4. ``Prior`` must be a protobuf object encoding the prior parameters. +.. Finally, a ``ConjugateHierarchy`` takes care of the implementation of some methods that are specific to conjugate models. -Finally, a ``ConjugateHierarchy`` takes care of the implementation of some methods that are specific to conjugate models. -------- -Classes -------- +---------------- +Abstract Classes +---------------- .. doxygenclass:: AbstractHierarchy :project: bayesmix @@ -79,9 +92,11 @@ Classes .. doxygenclass:: BaseHierarchy :project: bayesmix :members: -.. doxygenclass:: ConjugateHierarchy - :project: bayesmix - :members: + +--------------------------------- +Classes for Conjugate Hierarchies +--------------------------------- + .. doxygenclass:: NNIGHierarchy :project: bayesmix :members: @@ -91,3 +106,17 @@ Classes .. doxygenclass:: LinRegUniHierarchy :project: bayesmix :members: + +------------------------------------- +Classes for Non-Conjugate Hierarchies +------------------------------------- + +.. doxygenclass:: NNxIGHierarchy + :project: bayesmix + :members: +.. doxygenclass:: LapNIGHierarchy + :project: bayesmix + :members: +.. 
doxygenclass:: FAHierarchy + :project: bayesmix + :members: diff --git a/docs/index.rst b/docs/index.rst index 59d19db44..271387414 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,7 +32,8 @@ There are currently three submodules to the ``bayesmix`` library, represented by Further, we employ Protocol buffers for several purposes, including serialization. The list of all protos with their docs is available in the ``protos`` link below. .. toctree:: - :maxdepth: 1 + :maxdepth: 2 + :titlesonly: :caption: API: library submodules algorithms @@ -43,11 +44,15 @@ Further, we employ Protocol buffers for several purposes, including serializatio utils - Tutorials ========= -:doc:`tutorial` +.. toctree:: + :maxdepth: 1 + + tutorial + +.. :doc:`tutorial` Python interface diff --git a/docs/likelihoods.rst b/docs/likelihoods.rst new file mode 100644 index 000000000..72f2d73cd --- /dev/null +++ b/docs/likelihoods.rst @@ -0,0 +1,85 @@ +bayesmix/hierarchies/likelihoods + +Likelihoods +=========== + +The ``Likelihood`` sub-module represents the likelihood we have assumed for the data in a given cluster. Each ``Likelihood`` class represents the sampling model + +.. math:: + y_1, \ldots, y_k \mid \bm{\tau} \stackrel{\small\mathrm{iid}}{\sim} f(\cdot \mid \bm{\tau}) + +for a specific choice of the probability density function :math:`f`. + +------------------------- +Main operations performed +------------------------- + +A likelihood object must be able to perform the following operations: + +a. First of all, we require the ``lpdf()`` and ``lpdf\_grid()`` methods, which simply evaluate the loglikelihood in a given point or in a grid of points (also in case of a \emph{dependent} likelihood, i.e., with covariates associated to each observation) [``lpdf()`` and ``lpdf_grid``] +b. 
In case you want to rely on a Metropolis-like updater, the likelihood needs to evaluate the likelihood of the whole cluster starting from the vector of unconstrained parameters [``cluster_lpdf_from_unconstrained()``]. Observe that the ``AbstractLikelihood`` class provides two such methods, one returning a ``double`` and one returning a ``stan::math::var``. The latter is used to automatically compute the gradient of the likelihood via Stan's automatic differentiation, if needed. In practice, users do not need to implement both methods separately and can implement only one templated method. +c. manage the insertion and deletion of a datum in the cluster [``add_datum`` and ``remove_datum``] +d. update the summary statistics associated to the likelihood [``update_summary_statistics``]. Summary statistics (when available) are used to evaluate the likelihood function on the whole cluster, as well as to perform the posterior updates of :math:`\bm{\tau}`. This usually gives a substantial speedup + +-------------- +Code structure +-------------- + +In principle, the ``Likelihood`` classes are responsible only for evaluating the log-likelihood function given a specific choice of parameters :math:`\bm{\tau}`. +Therefore, a simple inheritance structure would seem appropriate. However, the nature of the parameters :math:`\bm{\tau}` can be very different across different models (think for instance of the difference between the univariate normal and the multivariate normal parameters). As such, we employ CRTP to manage the polymorphic nature of ``Likelihood`` classes. + +The class ``AbstractLikelihood`` defines the API, i.e. all the methods that need to be called from outside of a ``Likelihood`` class. +A template class ``BaseLikelihood`` inherits from ``AbstractLikelihood`` and implements some of the necessary virtual methods, which need not be implemented by the child classes. + +Instead, child classes **must** implement: + +a. 
``compute_lpdf``: evaluates :math:`k(x \mid \theta_h)` +b. ``update_sum_stats``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy +c. ``clear_summary_statistics``: clears all the summary statistics +d. ``is_dependent``: defines if the given likelihood depends on covariates +e. ``is_multivariate``: defines if the given likelihood is for multivariate data + +In case the likelihood needs to be used in a Metropolis-like updater, child classes **should** also implement: + +f. ``cluster_lpdf_from_unconstrained``: evaluates :math:`\prod_{i: c_i = h} k(x_i \mid \tilde{\theta}_h)`, where :math:`\tilde{\theta}_h` is the vector of unconstrained parameters. + +---------------- +Abstract Classes +---------------- + +.. doxygenclass:: AbstractLikelihood + :project: bayesmix + :members: +.. doxygenclass:: BaseLikelihood + :project: bayesmix + :members: + +---------------------------------- +Classes for Univariate Likelihoods +---------------------------------- + +.. doxygenclass:: UniNormLikelihood + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: UniLinRegLikelihood + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: LaplaceLikelihood + :project: bayesmix + :members: + :protected-members: + +------------------------------------ +Classes for Multivariate Likelihoods +------------------------------------ + +.. doxygenclass:: MultiNormLikelihood + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: FALikelihood + :project: bayesmix + :members: + :protected-members: diff --git a/docs/mixings.rst b/docs/mixings.rst index 67d2b1074..64f3b593f 100644 --- a/docs/mixings.rst +++ b/docs/mixings.rst @@ -34,6 +34,9 @@ Classes .. doxygenclass:: PitYorMixing :project: bayesmix :members: +.. doxygenclass:: MixtureFiniteMixing + :project: bayesmix + :members: .. 
doxygenclass:: TruncatedSBMixing :project: bayesmix :members: diff --git a/docs/prior_models.rst b/docs/prior_models.rst new file mode 100644 index 000000000..002ab1f85 --- /dev/null +++ b/docs/prior_models.rst @@ -0,0 +1,87 @@ +bayesmix/hierarchies/prior_models + +Prior Models +============ + +A ``PriorModel`` represents the prior for the parameters in the likelihood, i.e. + +.. math:: + \bm{\tau} \sim G_{0} + +with :math:`G_{0}` being a suitable prior on the parameter space. We also allow for more flexible priors by adding a further level of randomness (i.e. the hyperprior) on the parameter characterizing :math:`G_{0}`. + +------------------------- +Main operations performed +------------------------- + +A prior model object must be able to perform the following operations: + +a. First of all, ``lpdf()`` and ``lpdf_from_unconstrained()`` methods evaluate the log-prior density function at the current state :math:`\bm \tau` or its unconstrained representation. +In particular, ``lpdf_from_unconstrained()`` is needed by Metropolis-like updaters. + +b. The ``sample()`` method generates a draw from the prior distribution. If ``hier_hypers`` is ``nullptr``, the prior hyperparameter values are used. +To allow sampling from the full conditional distribution in case of semi-conjugate hierarchies, we introduce the ``hier_hypers`` parameter, which is a pointer to a ``Protobuf`` message storing the hierarchy hyperparameters to use for the sampling. + +c. The ``update_hypers()`` method updates the prior hyperparameters, given the vector of all cluster states. + + +-------------- +Code structure +-------------- + +As for the ``Likelihood`` classes we employ the Curiously Recurring Template Pattern to manage the polymorphic nature of ``PriorModel`` classes. + +The class ``AbstractPriorModel`` defines the API, i.e. all the methods that need to be called from outside of a ``PriorModel`` class. 
+A template class ``BasePriorModel`` inherits from ``AbstractPriorModel`` and implements some of the necessary virtual methods, which need not be implemented by the child classes. + +Instead, child classes **must** implement: + +a. ``lpdf``: evaluates :math:`G_0(\theta_h)` +b. ``sample``: samples from :math:`G_0` given a hyperparameters (passed as a pointer). If ``hier_hypers`` is ``nullptr``, the prior hyperparameter values are used. +c. ``set_hypers_from_proto``: sets the hyperparameters from a ``Probuf`` message +d. ``get_hypers_proto``: returns the hyperparameters as a ``Probuf`` message +e. ``initialize_hypers``: provides a default initialization of hyperparameters + +In case you want to use a Metropolis-like updater, child classes **should** also implement: + +f. ``lpdf_from_unconstrained``: evaluates :math:`G_0(\tilde{\theta}_h)`, where :math:`\tilde{\theta}_h` is the vector of unconstrained parameters. + +---------------- +Abstract Classes +---------------- + +.. doxygenclass:: AbstractPriorModel + :project: bayesmix + :members: +.. doxygenclass:: BasePriorModel + :project: bayesmix + :members: + +-------------------- +Non-abstract Classes +-------------------- + +.. doxygenclass:: NIGPriorModel + :project: bayesmix + :members: + :protected-members: + +.. doxygenclass:: NxIGPriorModel + :project: bayesmix + :members: + :protected-members: + +.. doxygenclass:: NWPriorModel + :project: bayesmix + :members: + :protected-members: + +.. doxygenclass:: MNIGPriorModel + :project: bayesmix + :members: + :protected-members: + +.. doxygenclass:: FAPriorModel + :project: bayesmix + :members: + :protected-members: diff --git a/docs/protos.html b/docs/protos.html index abc4c641b..26b9929a3 100644 --- a/docs/protos.html +++ b/docs/protos.html @@ -3,8 +3,12 @@ Protocol Documentation - - + + - - + -

Protocol Documentation

Table of Contents

- - -
-

algorithm_id.proto

Top -
-

- - - - -

AlgorithmId

-

Enum for the different types of algorithms.

References

[1] R. M. Neal, Markov Chain Sampling Methods for Dirichlet Process Mixture Models. JCGS(2000)

[2] H. Ishwaran and L. F. James, Gibbs Sampling Methods for Stick-Breaking Priors. JASA(2001)

[3] S. Jain and R. M. Neal, A Split-Merge Markov Chain Monte Carlo Procedure for the Dirichlet Process Mixture Model. JCGS (2004)

[4] M. Kalli, J. Griffin and S. G. Walker, Slice sampling mixture models. Stat and Comp. (2011)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
NameNumberDescription
UNKNOWN_ALGORITHM0

Neal21

Neal's Algorithm 2, see [1]

Neal32

Neal's Algorithm 3, see [1]

Neal83

Neal's Algorithm 8, see [1]

BlockedGibbs4

Ishwaran and James Blocked Gibbs, see [2]

SplitMerge5

Jain and Neal's Split&Merge, see [3]. NOT IMPLEMENTED YET!

Slice6

Slice sampling, see [4]. NOT IMPLEMENTED YET!

- - - - - - - -
-

algorithm_params.proto

Top -
-

- - -

AlgorithmParams

-

Parameters used in the BaseAlgorithm class and childs.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
algo_idstring

Id of the Algorithm. Must match the ones in the AlgorithmId enum

rng_seeduint32

Seed for the random number generator

iterationsuint32

Total number of iterations of the MCMC chain

burninuint32

Number of iterations to discard as burn-in

init_num_clustersuint32

Number of clusters to initialize the algorithm. It may be overridden by conditional mixings for which the number of components is fixed (e.g. TruncatedSBMixing). In this case, this value is ignored.

neal8_n_auxuint32

Number of auxiliary unique values for the Neal8 algorithm

splitmerge_n_restr_gs_updatesuint32

Number of restricted GS scans for each MH step.

splitmerge_n_mh_updatesuint32

Number of MH updates for each iteration of Split and Merge algorithm.

splitmerge_n_full_gs_updatesuint32

Number of full GS scans for each iteration of Split and Merge algorithm.

- - - - - - - - - - - - - -
-

algorithm_state.proto

Top -
-

- - -

AlgorithmState

-

This message represents the state of a Gibbs sampler for

a mixture model. All algorithms must be able to handle this

message, by filling it with the current state of the sampler

in the `get_state_as_proto` method.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
cluster_statesAlgorithmState.ClusterStaterepeated

The state of each cluster

cluster_allocsint32repeated

Vector of allocations into clusters, one for each observation

mixing_stateMixingState

The state of the `Mixing`

iteration_numint32

The iteration number

hierarchy_hypersAlgorithmState.HierarchyHypers

The current values of the hyperparameters of the hierarchy

- - - - - -

AlgorithmState.ClusterState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
uni_ls_stateUniLSState

State of a univariate location-scale family

multi_ls_stateMultiLSState

State of a multivariate location-scale family

lin_reg_uni_ls_stateLinRegUniLSState

State of a linear regression univariate location-scale family

general_stateVector

Just a vector of doubles

fa_stateFAState

State of a Mixture of Factor Analysers

cardinalityint32

How many observations are in this cluster

- - - - - -

AlgorithmState.HierarchyHypers

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fake_priorEmptyPrior

nnig_stateNIGDistribution

nnw_stateNWDistribution

lin_reg_uni_stateMultiNormalIGDistribution

lapnig_stateLapNIGState

fa_stateFAPriorDistribution

- - - - - - - - - - - - - -
-

distribution.proto

Top -
-

- - -

BetaDistribution

-

Parameters defining a beta distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
shape_adouble

shape_bdouble

- - - - - -

GammaDistribution

-

Parameters defining a gamma distribution with density

f(x) = x^(shape-1) * exp(-rate * x) / Gamma(shape)

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
shapedouble

ratedouble

- - - - - -

InvWishartDistribution

-

Parameters defining an Inverse Wishart distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
deg_freedouble

scaleMatrix

- - - - - -

MultiNormalDistribution

-

Parameters defining a multivariate normal distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

varMatrix

- - - - - -

MultiNormalIGDistribution

-

Parameters for the Normal Inverse Gamma distribution commonly employed in

linear regression models, with density

f(beta, var) = N(beta | mean, var * var_scaling^{-1}) * IG(var | shape, scale)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

var_scalingMatrix

shapedouble

scaledouble

- - - - - -

NIGDistribution

-

Parameters of a Normal Inverse-Gamma distribution

with density

f(x, y) = N(x | mu, y/var_scaling) * IG(y | shape, scale)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

var_scalingdouble

shapedouble

scaledouble

- - - - - -

NWDistribution

-

Parameters of a Normal Wishart distribution

with density

f(x, y) = N(x | mu, (y * var_scaling)^{-1}) * IW(y | deg_free, scale)

where x is a vector and y is a matrix (spd)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

var_scalingdouble

deg_freedouble

scaleMatrix

- - - - - -

UniNormalDistribution

-

Parameters defining a univariate normal distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

vardouble

- - - - - - - - - - - - - -
-

hierarchy_id.proto

Top -
-

- - - - -

HierarchyId

-

Enum for the different types of Hierarchy.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
NameNumberDescription
UNKNOWN_HIERARCHY0

NNIG1

Normal - Normal Inverse Gamma

NNW2

Normal - Normal Wishart

LinRegUni3

Linear Regression (univariate response)

LapNIG4

Laplace - Normal Inverse Gamma

FA5

Factor Analysers

- - - - - - - -
-

hierarchy_prior.proto

Top -
-

- - -

EmptyPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fake_fielddouble

- - - - - -

FAPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesFAPriorDistribution

- - - - - -

FAPriorDistribution

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mutildeVector

betaVector

phidouble

alpha0double

quint32

- - - - - -

LapNIGPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesLapNIGState

- - - - - -

LapNIGState

-

Prior for the parameters of the base measure in a Laplace - Normal Inverse Gamma hierarchy

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

vardouble

shapedouble

scaledouble

mh_mean_vardouble

mh_log_scale_vardouble

- - - - - -

LinRegUniPrior

-

Prior for the parameters of the base measure in a Normal mixture model with a covariate-dependent

location.

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesMultiNormalIGDistribution

- - - - - -

NNIGPrior

-

Prior for the parameters of the base measure in a Normal-Normal Inverse Gamma hierarchy

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesNIGDistribution

no prior, just fixed values

normal_mean_priorNNIGPrior.NormalMeanPrior

prior on the mean

ngg_priorNNIGPrior.NGGPrior

prior on the mean, var_scaling, and scale

- - - - - -

NNIGPrior.NGGPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorUniNormalDistribution

var_scaling_priorGammaDistribution

shapedouble

scale_priorGammaDistribution

- - - - - -

NNIGPrior.NormalMeanPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorUniNormalDistribution

var_scalingdouble

shapedouble

scaledouble

- - - - - -

NNWPrior

-

Prior for the parameters of the base measure in a Normal-Normal Wishart hierarchy

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesNWDistribution

no prior, just fixed values

normal_mean_priorNNWPrior.NormalMeanPrior

prior on the mean

ngiw_priorNNWPrior.NGIWPrior

prior on the mean, var_scaling, and scale

- - - - - -

NNWPrior.NGIWPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorMultiNormalDistribution

var_scaling_priorGammaDistribution

deg_freedouble

scale_priorInvWishartDistribution

- - - - - -

NNWPrior.NormalMeanPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorMultiNormalDistribution

var_scalingdouble

deg_freedouble

scaleMatrix

- - - - - - - - - - - - - -
-

ls_state.proto

Top -
-

- - -

FAState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
muVector

psiVector

etaMatrix

lambdaMatrix

- - - - - -

LinRegUniLSState

-

Parameters of a univariate linear regression

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
regression_coeffsVector

regression coefficients

vardouble

variance of the noise

- - - - - -

MultiLSState

-

Parameters of a multivariate location-scale family of distributions,

parameterized by mean and precision (inverse of variance). For

convenience, we also store the Cholesky factor of the precision matrix.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

precMatrix

prec_cholMatrix

- - - - - -

UniLSState

-

Parameters of a univariate location-scale family of distributions.

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

vardouble

- - - - - - - - - - - - - -
-

matrix.proto

Top -
-

- - -

Matrix

-

Message representing a matrix of doubles.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
rowsint32

number of rows

colsint32

number of columns

datadoublerepeated

matrix elements

rowmajorbool

if true, the data is read in row-major order

- - - - - -

Vector

-

Message representing a vector of doubles.

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
sizeint32

number of elements in the vector

datadoublerepeated

vector elements

- - - - - - - - - - - - - -
-

mixing_id.proto

Top -
-

- - - - -

MixingId

-

Enum for the different types of Mixing.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
NameNumberDescription
UNKNOWN_MIXING0

DP1

Dirichlet Process

PY2

Pitman-Yor Process

LogSB3

Logit Stick-Breaking Process

TruncSB4

Truncated Stick-Breaking Process

MFM5

Mixture of finite mixtures

- - - - - - - -
-

mixing_prior.proto

Top -
-

- - -

DPPrior

-

Prior for the concentration parameter of a Dirichlet process

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valueDPState

No prior, just a fixed value

gamma_priorDPPrior.GammaPrior

Gamma prior on the total mass

- - - - - -

DPPrior.GammaPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmass_priorGammaDistribution

- - - - - -

LogSBPrior

-

Definition of the parameters of a Logit-Stick Breaking process.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
normal_priorMultiNormalDistribution

Normal prior on the regression coefficients

step_sizedouble

Steps size for the MALA algorithm used for posterior inference (TODO: move?)

num_componentsuint32

Number of components in the process

- - - - - -

MFMPrior

-

Prior for the Poisson rate and Dirichlet parameters of a MFM (Finite Dirichlet) process.

For the moment, we only support fixed values

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valueMFMState

No prior, just a fixed value

- - - - - -

PYPrior

-

Prior for the strength and discount parameters of a Pitman-Yor process.

For the moment, we only support fixed values

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesPYState

- - - - - -

TruncSBPrior

-

Definition of the parameters of a truncated Stick-Breaking process

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
beta_priorsTruncSBPrior.BetaPriors

General stick-breaking distributions

dp_priorTruncSBPrior.DPPrior

Truncated Dirichlet process

py_priorTruncSBPrior.PYPrior

Truncated Pitman-Yor process

mfm_priorTruncSBPrior.MFMPrior

num_componentsuint32

Number of components in the process

- - - - - -

TruncSBPrior.BetaPriors

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
beta_distributionsBetaDistributionrepeated

General stick-breaking distributions

- - - - - -

TruncSBPrior.DPPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

- - - - - -

TruncSBPrior.MFMPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

- - - - - -

TruncSBPrior.PYPrior

-

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
strengthdouble

Truncated Pitman-Yor process

discountdouble

- - - - - - - - - - - - - -
-

mixing_state.proto

Top -
-

- - -

DPState

-

State of a Dirichlet process

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmassdouble

the total mass of the DP

- - - - - -

LogSBState

-

State of a Logit-Stick Breaking process

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
regression_coeffsMatrix

Num_Components x Num_Features matrix. Each row is the regression coefficients for a component.

- - - - - -

MFMState

-

State of a MFM (Finite Dirichlet) process

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
lambdadouble

rate parameter of Poisson prior on number of compunents of the MFM

gammadouble

parameter of the dirichlet distribution for the mixing weights

- - - - - -

MixingState

-

Wrapper of all possible mixing states into a single oneof

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
dp_stateDPState

py_statePYState

log_sb_stateLogSBState

trunc_sb_stateTruncSBState

mfm_stateMFMState

- - - - - -

PYState

-

State of a Pitman-Yor process

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
strengthdouble

discountdouble

- - - - - -

TruncSBState

-

State of a truncated sitck breaking process. For convenice we store also the logarithm of the weights

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
sticksVector

logweightsVector

- - - - - - - - - - - - - -
-

semihdp.proto

Top -
-

- - -

SemiHdpParams

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
pseudo_priorSemiHdpParams.PseudoPriorParams

dirichlet_concentrationdouble

rest_allocs_updatestring

Either "full", "metro_base", "metro_dist"

totalmass_restdouble

totalmass_hdpdouble

w_priorSemiHdpParams.WPriorParams

- - - - - -

SemiHdpParams.PseudoPriorParams

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
card_weightdouble

mean_perturb_sddouble

var_perturb_fracdouble

- - - - - -

SemiHdpParams.WPriorParams

-

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
shape1double

shape2double

- - - - - -

SemiHdpState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
restaurantsSemiHdpState.RestaurantStaterepeated

groupsSemiHdpState.GroupStaterepeated

tausSemiHdpState.ClusterStaterepeated

cint32repeated

wdouble

- - - - - -

SemiHdpState.ClusterState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
uni_ls_stateUniLSState

multi_ls_stateMultiLSState

lin_reg_uni_ls_stateLinRegUniLSState

cardinalityint32

- - - - - -

SemiHdpState.GroupState

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
cluster_allocsint32repeated

- - - - - -

SemiHdpState.RestaurantState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
theta_starsSemiHdpState.ClusterStaterepeated

n_by_clusint32repeated

table_to_sharedint32repeated

table_to_idioint32repeated

- - - - - - - - - - - - +
+

algorithm_id.proto

+ Top +
+

+ +

AlgorithmId

+

Enum for the different types of algorithms.

+

References

+

+ [1] R. M. Neal, Markov Chain Sampling Methods for Dirichlet Process + Mixture Models. JCGS(2000) +

+

+ [2] H. Ishwaran and L. F. James, Gibbs Sampling Methods for Stick-Breaking + Priors. JASA(2001) +

+

+ [3] S. Jain and R. M. Neal, A Split-Merge Markov Chain Monte Carlo + Procedure for the Dirichlet Process Mixture Model. JCGS (2004) +

+

+ [4] M. Kalli, J. Griffin and S. G. Walker, Slice sampling mixture models. + Stat and Comp. (2011) +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameNumberDescription
UNKNOWN_ALGORITHM0

Neal21

Neal's Algorithm 2, see [1]

Neal32

Neal's Algorithm 3, see [1]

Neal83

Neal's Algorithm 8, see [1]

BlockedGibbs4

Ishwaran and James Blocked Gibbs, see [2]

SplitMerge5 +

+ Jain and Neal's Split&Merge, see [3]. NOT IMPLEMENTED YET! +

+
Slice6

Slice sampling, see [4]. NOT IMPLEMENTED YET!

+ +
+

algorithm_params.proto

+ Top +
+

+ +

AlgorithmParams

+

Parameters used in the BaseAlgorithm class and childs.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
algo_idstring +

+ Id of the Algorithm. Must match the ones in the AlgorithmId enum +

+
rng_seeduint32

Seed for the random number generator

iterationsuint32

Total number of iterations of the MCMC chain

burninuint32

Number of iterations to discard as burn-in

init_num_clustersuint32 +

+ Number of clusters to initialize the algorithm. It may be + overridden by conditional mixings for which the number of + components is fixed (e.g. TruncatedSBMixing). In this case, this + value is ignored. +

+
neal8_n_auxuint32 +

Number of auxiliary unique values for the Neal8 algorithm

+
splitmerge_n_restr_gs_updatesuint32

Number of restricted GS scans for each MH step.

splitmerge_n_mh_updatesuint32 +

+ Number of MH updates for each iteration of Split and Merge + algorithm. +

+
splitmerge_n_full_gs_updatesuint32 +

+ Number of full GS scans for each iteration of Split and Merge + algorithm. +

+
+ +
+

algorithm_state.proto

+ Top +
+

+ +

AlgorithmState

+

This message represents the state of a Gibbs sampler for

+

a mixture model. All algorithms must be able to handle this

+

message, by filling it with the current state of the sampler

+

in the `get_state_as_proto` method.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
cluster_states + AlgorithmState.ClusterState + repeated

The state of each cluster

cluster_allocsint32repeated +

Vector of allocations into clusters, one for each observation

+
mixing_stateMixingState

The state of the `Mixing`

iteration_numint32

The iteration number

hierarchy_hypers + AlgorithmState.HierarchyHypers + +

The current values of the hyperparameters of the hierarchy

+
+ +

+ AlgorithmState.ClusterState +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
uni_ls_stateUniLSState

State of a univariate location-scale family

multi_ls_stateMultiLSState

State of a multivariate location-scale family

lin_reg_uni_ls_stateLinRegUniLSState +

State of a linear regression univariate location-scale family

+
general_stateVector

Just a vector of doubles

fa_stateFAState

State of a Mixture of Factor Analysers

cardinalityint32

How many observations are in this cluster

+ +

+ AlgorithmState.HierarchyHypers +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
general_stateVector

nnig_stateNIGDistribution

nnw_stateNWDistribution

lin_reg_uni_state + MultiNormalIGDistribution +

nnxig_stateNxIGDistribution

fa_state + FAPriorDistribution +

+ +
+

distribution.proto

+ Top +
+

+ +

BetaDistribution

+

Parameters defining a beta distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
shape_adouble

shape_bdouble

+ +

GammaDistribution

+

Parameters defining a gamma distribution with density

+

f(x) = x^(shape-1) * exp(-rate * x) / Gamma(shape)

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
shapedouble

ratedouble

+ +

InvWishartDistribution

+

Parameters defining an Inverse Wishart distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
deg_freedouble

scaleMatrix

+ +

MultiNormalDistribution

+

Parameters defining a multivariate normal distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

varMatrix

+ +

MultiNormalIGDistribution

+

+ Parameters for the Normal Inverse Gamma distribution commonly employed in +

+

linear regression models, with density

+

+ f(beta, var) = N(beta | mean, var * var_scaling^{-1}) * IG(var | shape, + scale) +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

var_scalingMatrix

shapedouble

scaledouble

+ +

NIGDistribution

+

Parameters of a Normal Inverse-Gamma distribution

+

with density

+

f(x, y) = N(x | mu, y/var_scaling) * IG(y | shape, scale)

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

var_scalingdouble

shapedouble

scaledouble

+ +

NWDistribution

+

Parameters of a Normal Wishart distribution

+

with density

+

f(x, y) = N(x | mu, (y * var_scaling)^{-1}) * IW(y | deg_free, scale)

+

where x is a vector and y is a matrix (spd)

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

var_scalingdouble

deg_freedouble

scaleMatrix

+ +

NxIGDistribution

+

Parameters of a Normal x Inverse-Gamma distribution

+

with density

+

f(x, y) = N(x | mu, var) * IG(y | shape, scale)

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

shapedouble

scaledouble

+ +

UniNormalDistribution

+

Parameters defining a univariate normal distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

+ +
+

hierarchy_id.proto

+ Top +
+

+ +

HierarchyId

+

Enum for the different types of Hierarchy.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameNumberDescription
UNKNOWN_HIERARCHY0

NNIG1

Normal - Normal Inverse Gamma

NNW2

Normal - Normal Wishart

LinRegUni3

Linear Regression (univariate response)

LapNIG4

Laplace - Normal Inverse Gamma

FA5

Factor Analysers

NNxIG6

Normal - Normal x Inverse Gamma

+ +
+

hierarchy_prior.proto

+ Top +
+

+ +

EmptyPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fake_fielddouble

+ +

FAPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_values + FAPriorDistribution +

+ +

FAPriorDistribution

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mutildeVector

betaVector

phidouble

alpha0double

quint32

+ +

LapNIGPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesLapNIGState

+ +

LapNIGState

+

+ Prior for the parameters of the base measure in a Laplace - Normal Inverse + Gamma hierarchy +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

shapedouble

scaledouble

mh_mean_vardouble

mh_log_scale_vardouble

+ +

LinRegUniPrior

+

+ Prior for the parameters of the base measure in a Normal mixture model + with a covariate-dependent +

+

location.

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_values + MultiNormalIGDistribution +

+ +

NNIGPrior

+

+ Prior for the parameters of the base measure in a Normal-Normal Inverse + Gamma hierarchy +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesNIGDistribution

no prior, just fixed values

normal_mean_prior + NNIGPrior.NormalMeanPrior +

prior on the mean

ngg_priorNNIGPrior.NGGPrior

prior on the mean, var_scaling, and scale

+ +

NNIGPrior.NGGPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + UniNormalDistribution +

var_scaling_priorGammaDistribution

shapedouble

scale_priorGammaDistribution

+ +

NNIGPrior.NormalMeanPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + UniNormalDistribution +

var_scalingdouble

shapedouble

scaledouble

+ +

NNWPrior

+

+ Prior for the parameters of the base measure in a Normal-Normal Wishart + hierarchy +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesNWDistribution

no prior, just fixed values

normal_mean_prior + NNWPrior.NormalMeanPrior +

prior on the mean

ngiw_priorNNWPrior.NGIWPrior

prior on the mean, var_scaling, and scale

+ +

NNWPrior.NGIWPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + MultiNormalDistribution +

var_scaling_priorGammaDistribution

deg_freedouble

scale_prior + InvWishartDistribution +

+ +

NNWPrior.NormalMeanPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + MultiNormalDistribution +

var_scalingdouble

deg_freedouble

scaleMatrix

+ +

NNxIGPrior

+

+ Prior for the parameters of the base measure in a Normal-Normal x Inverse + Gamma hierarchy +

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesNxIGDistribution

no prior, just fixed values

+ +
+

ls_state.proto

+ Top +
+

+ +

FAState

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
muVector

psiVector

etaMatrix

lambdaMatrix

+ +

LinRegUniLSState

+

Parameters of a univariate linear regression

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
regression_coeffsVector

regression coefficients

vardouble

variance of the noise

+ +

MultiLSState

+

Parameters of a multivariate location-scale family of distributions,

+

parameterized by mean and precision (inverse of variance). For

+

+ convenience, we also store the Cholesky factor of the precision matrix. +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

precMatrix

prec_cholMatrix

+ +

UniLSState

+

Parameters of a univariate location-scale family of distributions.

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

+ +
+

matrix.proto

+ Top +
+

+ +

Matrix

+

Message representing a matrix of doubles.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
rowsint32

number of rows

colsint32

number of columns

datadoublerepeated

matrix elements

rowmajorbool

if true, the data is read in row-major order

+ +

Vector

+

Message representing a vector of doubles.

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
sizeint32

number of elements in the vector

datadoublerepeated

vector elements

+ +
+

mixing_id.proto

+ Top +
+

+ +

MixingId

+

Enum for the different types of Mixing.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameNumberDescription
UNKNOWN_MIXING0

DP1

Dirichlet Process

PY2

Pitman-Yor Process

LogSB3

Logit Stick-Breaking Process

TruncSB4

Truncated Stick-Breaking Process

MFM5

Mixture of finite mixtures

+ +
+

mixing_prior.proto

+ Top +
+

+ +

DPPrior

+

Prior for the concentration parameter of a Dirichlet process

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valueDPState

No prior, just a fixed value

gamma_priorDPPrior.GammaPrior

Gamma prior on the total mass

+ +

DPPrior.GammaPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmass_priorGammaDistribution

+ +

LogSBPrior

+

Definition of the parameters of a Logit-Stick Breaking process.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
normal_prior + MultiNormalDistribution +

Normal prior on the regression coefficients

step_sizedouble +

+ Steps size for the MALA algorithm used for posterior inference + (TODO: move?) +

+
num_componentsuint32

Number of components in the process

+ +

MFMPrior

+

+ Prior for the Poisson rate and Dirichlet parameters of a MFM (Finite + Dirichlet) process. +

+

For the moment, we only support fixed values

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valueMFMState

No prior, just a fixed value

+ +

PYPrior

+

+ Prior for the strength and discount parameters of a Pitman-Yor process. +

+

For the moment, we only support fixed values

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesPYState

+ +

TruncSBPrior

+

Definition of the parameters of a truncated Stick-Breaking process

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
beta_priors + TruncSBPrior.BetaPriors +

General stick-breaking distributions

dp_prior + TruncSBPrior.DPPrior +

Truncated Dirichlet process

py_prior + TruncSBPrior.PYPrior +

Truncated Pitman-Yor process

mfm_prior + TruncSBPrior.MFMPrior +

num_componentsuint32

Number of components in the process

+ +

TruncSBPrior.BetaPriors

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
beta_distributionsBetaDistributionrepeated

General stick-breaking distributions

+ +

TruncSBPrior.DPPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

+ +

TruncSBPrior.MFMPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

+ +

TruncSBPrior.PYPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
strengthdouble

Truncated Pitman-Yor process

discountdouble

+ +
+

mixing_state.proto

+ Top +
+

+ +

DPState

+

State of a Dirichlet process

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmassdouble

the total mass of the DP

+ +

LogSBState

+

State of a Logit-Stick Breaking process

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
regression_coeffsMatrix +

+ Num_Components x Num_Features matrix. Each row is the regression + coefficients for a component. +

+
+ +

MFMState

+

State of a MFM (Finite Dirichlet) process

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
lambdadouble +

+ rate parameter of Poisson prior on number of compunents of the MFM +

+
gammadouble +

+ parameter of the dirichlet distribution for the mixing weights +

+
+ +

MixingState

+

Wrapper of all possible mixing states into a single oneof

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
dp_stateDPState

py_statePYState

log_sb_stateLogSBState

trunc_sb_stateTruncSBState

mfm_stateMFMState

+ +

PYState

+

State of a Pitman-Yor process

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
strengthdouble

discountdouble

+ +

TruncSBState

+

+ State of a truncated sitck breaking process. For convenice we store also + the logarithm of the weights +

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
sticksVector

logweightsVector

+ +
+

mixture_model.proto

+ Top +
+

+ +

HierarchyPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
nnig_priorNNIGPrior

lapnig_priorLapNIGPrior

nnw_priorNNWPrior

lin_reg_priorLinRegUniPrior

fa_priorFAPrior

+ +

MixingPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
dp_priorDPPrior

py_priorPYPrior

log_sb_priorLogSBPrior

trunc_sb_priorTruncSBPrior

+ +

MixtureModel

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mixingMixingId

hierarchyHierarchyId

mixing_priorMixingPrior

hierarchy_priorHierarchyPrior

+ +
+

semihdp.proto

+ Top +
+

+ +

SemiHdpParams

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
pseudo_prior + SemiHdpParams.PseudoPriorParams +

dirichlet_concentrationdouble

rest_allocs_updatestring +

+ Either "full", "metro_base", "metro_dist" +

+
totalmass_restdouble

totalmass_hdpdouble

w_prior + SemiHdpParams.WPriorParams +

+ +

+ SemiHdpParams.PseudoPriorParams +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
card_weightdouble

mean_perturb_sddouble

var_perturb_fracdouble

+ +

SemiHdpParams.WPriorParams

+

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
shape1double

shape2double

+ +

SemiHdpState

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
restaurants + SemiHdpState.RestaurantState + repeated

groups + SemiHdpState.GroupState + repeated

taus + SemiHdpState.ClusterState + repeated

cint32repeated

wdouble

+ +

SemiHdpState.ClusterState

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
uni_ls_stateUniLSState

multi_ls_stateMultiLSState

lin_reg_uni_ls_stateLinRegUniLSState

cardinalityint32

+ +

SemiHdpState.GroupState

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
cluster_allocsint32repeated

+ +

+ SemiHdpState.RestaurantState +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
theta_stars + SemiHdpState.ClusterState + repeated

n_by_clusint32repeated

table_to_sharedint32repeated

table_to_idioint32repeated

Scalar Value Types

- + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
.proto TypeNotesC++JavaPythonGoC#PHPRuby
.proto TypeNotesC++JavaPythonGoC#PHPRuby
doubledoubledoublefloatfloat64doublefloatFloat
floatfloatfloatfloatfloat32floatfloatFloat
int32Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint32 instead.int32intintint32intintegerBignum or Fixnum (as required)
int64Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint64 instead.int64longint/longint64longinteger/stringBignum
uint32Uses variable-length encoding.uint32intint/longuint32uintintegerBignum or Fixnum (as required)
uint64Uses variable-length encoding.uint64longint/longuint64ulonginteger/stringBignum or Fixnum (as required)
sint32Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int32s.int32intintint32intintegerBignum or Fixnum (as required)
sint64Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int64s.int64longint/longint64longinteger/stringBignum
fixed32Always four bytes. More efficient than uint32 if values are often greater than 2^28.uint32intintuint32uintintegerBignum or Fixnum (as required)
fixed64Always eight bytes. More efficient than uint64 if values are often greater than 2^56.uint64longint/longuint64ulonginteger/stringBignum
sfixed32Always four bytes.int32intintint32intintegerBignum or Fixnum (as required)
sfixed64Always eight bytes.int64longint/longint64longinteger/stringBignum
boolboolbooleanbooleanboolboolbooleanTrueClass/FalseClass
stringA string must always contain UTF-8 encoded or 7-bit ASCII text.stringStringstr/unicodestringstringstringString (UTF-8)
bytesMay contain any arbitrary sequence of bytes.stringByteStringstr[]byteByteStringstringString (ASCII-8BIT)
doubledoubledoublefloatfloat64doublefloatFloat
floatfloatfloatfloatfloat32floatfloatFloat
int32 + Uses variable-length encoding. Inefficient for encoding negative + numbers – if your field is likely to have negative values, use + sint32 instead. + int32intintint32intintegerBignum or Fixnum (as required)
int64 + Uses variable-length encoding. Inefficient for encoding negative + numbers – if your field is likely to have negative values, use + sint64 instead. + int64longint/longint64longinteger/stringBignum
uint32Uses variable-length encoding.uint32intint/longuint32uintintegerBignum or Fixnum (as required)
uint64Uses variable-length encoding.uint64longint/longuint64ulonginteger/stringBignum or Fixnum (as required)
sint32 + Uses variable-length encoding. Signed int value. These more + efficiently encode negative numbers than regular int32s. + int32intintint32intintegerBignum or Fixnum (as required)
sint64 + Uses variable-length encoding. Signed int value. These more + efficiently encode negative numbers than regular int64s. + int64longint/longint64longinteger/stringBignum
fixed32 + Always four bytes. More efficient than uint32 if values are often + greater than 2^28. + uint32intintuint32uintintegerBignum or Fixnum (as required)
fixed64 + Always eight bytes. More efficient than uint64 if values are often + greater than 2^56. + uint64longint/longuint64ulonginteger/stringBignum
sfixed32Always four bytes.int32intintint32intintegerBignum or Fixnum (as required)
sfixed64Always eight bytes.int64longint/longint64longinteger/stringBignum
boolboolbooleanbooleanboolboolbooleanTrueClass/FalseClass
string + A string must always contain UTF-8 encoded or 7-bit ASCII text. + stringStringstr/unicodestringstringstringString (UTF-8)
bytesMay contain any arbitrary sequence of bytes.stringByteStringstr[]byteByteStringstringString (ASCII-8BIT)
- diff --git a/docs/states.rst b/docs/states.rst new file mode 100644 index 000000000..68de1feb1 --- /dev/null +++ b/docs/states.rst @@ -0,0 +1,40 @@ +bayesmix/hierarchies/likelihoods/states + +States +====== + +``States`` are classes used to store parameters :math:`\theta_h` of every mixture component. +Their main purpose is to handle serialization and de-serialization of the state. +Moreover, they allow to go from the constrained to the unconstrained representation of the parameters (and viceversa) and compute the associated determinant of the Jacobian appearing in the change of density formula. + + +-------------- +Code Structure +-------------- + +All classes must inherit from the ``BaseState`` class + +.. doxygenclass:: State::BaseState + :project: bayesmix + :members: + +Depending on the chosen ``Updater``, the unconstrained representation might not be needed, and the methods ``get_unconstrained()``, ``set_from_unconstrained()`` and ``log_det_jac()`` might never be called. +Therefore, we do not force users to implement them. +Instead, the ``set_from_proto()`` and ``get_as_proto()`` are fundamental as they allow the interaction with Google's Protocol Buffers library. + +------------- +State Classes +------------- + +.. doxygenclass:: State::UniLS + :project: bayesmix + :members: + +.. doxygenclass:: State::MultiLS + :project: bayesmix + :members: + +.. doxygenclass:: State::FA + :project: bayesmix + :members: + :protected-members: diff --git a/docs/updaters.rst b/docs/updaters.rst new file mode 100644 index 000000000..ac3fb19c0 --- /dev/null +++ b/docs/updaters.rst @@ -0,0 +1,76 @@ +bayesmix/hierarchies/updaters + +Updaters +======== + +An ``Updater`` implements the machinery to provide a sampling from the full conditional distribution of a given hierarchy. + +The only operation performed is ``draw`` that samples from the full conditional, either exactly or via Markov chain Monte Carlo. + +.. 
doxygenclass:: AbstractUpdater + :project: bayesmix + :members: + +-------------- +Code Structure +-------------- + +We distinguish between semi-conjugate updaters and the metropolis-like updaters. + + +Semi Conjugate Updaters +----------------------- + +A semi-conjugate updater can be used when the full conditional distribution has the same form of the prior. Therefore, to sample from the full conditional, it is sufficient to call the ``draw`` method of the prior, but with an updated set of hyperparameters. + +The class ``SemiConjugateUpdater`` defines the API + +.. doxygenclass:: SemiConjugateUpdater + :project: bayesmix + :members: + +Classes inheriting from this one should only implement the ``compute_posterior_hypers(...)`` member function. + + +Metropolis-like Updaters +------------------------ + +A Metropolis updater uses the Metropolis-Hastings algorithm (or its variations) to sample from the full conditional density. + +.. doxygenclass:: MetropolisUpdater + :project: bayesmix + :members: + + +Classes inheriting from this one should only implement the ``sample_proposal(...)`` method, which samples from the porposal distribution, and the ``proposal_lpdf`` one, which evaluates the proposal density log-probability density function. + +--------------- +Updater Classes +--------------- + +.. doxygenclass:: RandomWalkUpdater + :project: bayesmix + :members: +.. doxygenclass:: MalaUpdater + :project: bayesmix + :members: +.. doxygenclass:: NNIGUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: NNxIGUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: NNWUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: MNIGUpdater + :project: bayesmix + :members: + :protected-members: +.. 
doxygenclass:: FAUpdater + :project: bayesmix + :members: + :protected-members: diff --git a/docs/utils.rst b/docs/utils.rst index 89c7e4533..45420f464 100644 --- a/docs/utils.rst +++ b/docs/utils.rst @@ -18,17 +18,20 @@ Distribution-related utilities :project: bayesmix ---------------------------------------------- -``Eigen`` input-output and matrix manipulation +``Eigen`` matrix manipulation utilities ---------------------------------------------- .. doxygenfile:: eigen_utils.h :project: bayesmix - :members: + +-------------------------------- +``Eigen`` input-output utilities +-------------------------------- .. doxygenfile:: io_utils.h :project: bayesmix -------------------------- -``protobuf`` input-output -------------------------- +----------------------------------- +``protobuf`` input-output utilities +----------------------------------- .. doxygenfile:: proto_utils.h :project: bayesmix diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 733662a26..09934a9f0 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,9 +1,12 @@ cmake_minimum_required(VERSION 3.13.0) project(examples_bayesmix) -add_executable(run_gamma $ +add_executable(run_gamma $ gamma_hierarchy/run_gamma_gamma.cc - gamma_hierarchy/gamma_gamma_hier.h + gamma_hierarchy/gammagamma_hierarchy.h + gamma_hierarchy/gamma_likelihood.h + gamma_hierarchy/gamma_prior_model.h + gamma_hierarchy/gammagamma_updater.h ) target_include_directories(run_gamma PUBLIC ${INCLUDE_PATHS}) diff --git a/examples/gamma_hierarchy/gamma_gamma_hier.h b/examples/gamma_hierarchy/gamma_gamma_hier.h deleted file mode 100644 index ecd43325b..000000000 --- a/examples/gamma_hierarchy/gamma_gamma_hier.h +++ /dev/null @@ -1,140 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_GAMMAGAMMA_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_GAMMAGAMMA_HIERARCHY_H_ - -#include -#include -#include - -#include -#include -#include - -#include "hierarchy_prior.pb.h" - -namespace GammaGamma { -//! 
Custom container for State values -struct State { - double rate; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - double shape, rate_alpha, rate_beta; -}; -}; // namespace GammaGamma - -class GammaGammaHierarchy - : public ConjugateHierarchy { - public: - GammaGammaHierarchy(const double shape, const double rate_alpha, - const double rate_beta) - : shape(shape), rate_alpha(rate_alpha), rate_beta(rate_beta) { - create_empty_prior(); - } - ~GammaGammaHierarchy() = default; - - double like_lpdf(const Eigen::RowVectorXd &datum) const override { - return stan::math::gamma_lpdf(datum(0), hypers->shape, state.rate); - } - - double marg_lpdf( - const GammaGamma::Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { - throw std::runtime_error("marg_lpdf() not implemented"); - return 0; - } - - GammaGamma::State draw(const GammaGamma::Hyperparams ¶ms) { - return GammaGamma::State{stan::math::gamma_rng( - params.rate_alpha, params.rate_beta, bayesmix::Rng::Instance().get())}; - } - - void update_summary_statistics(const Eigen::RowVectorXd &datum, - const bool add) { - if (add) { - data_sum += datum(0); - ndata += 1; - } else { - data_sum -= datum(0); - ndata -= 1; - } - } - - //! Computes and return posterior hypers given data currently in this cluster - GammaGamma::Hyperparams compute_posterior_hypers() { - GammaGamma::Hyperparams out; - out.shape = hypers->shape; - out.rate_alpha = hypers->rate_alpha + hypers->shape * ndata; - out.rate_beta = hypers->rate_beta + data_sum; - return out; - } - - void initialize_state() override { - state.rate = hypers->rate_alpha / hypers->rate_beta; - } - - void initialize_hypers() { - hypers->shape = shape; - hypers->rate_alpha = rate_alpha; - hypers->rate_beta = rate_beta; - } - - //! 
Removes every data point from this cluster - void clear_summary_statistics() { - data_sum = 0; - ndata = 0; - } - - bool is_multivariate() const override { return false; } - - void set_state_from_proto(const google::protobuf::Message &state_) override { - auto &statecast = google::protobuf::internal::down_cast< - const bayesmix::AlgorithmState::ClusterState &>(state_); - state.rate = statecast.general_state().data()[0]; - set_card(statecast.cardinality()); - } - - std::shared_ptr get_state_proto() - const override { - bayesmix::Vector state_; - state_.mutable_data()->Add(state.rate); - - auto out = std::make_unique(); - out->mutable_general_state()->CopyFrom(state_); - return out; - } - - void update_hypers(const std::vector - &states) override { - return; - } - - void write_hypers_to_proto( - google::protobuf::Message *const out) const override { - return; - } - - void set_hypers_from_proto( - const google::protobuf::Message &state_) override { - return; - } - - std::shared_ptr get_hypers_proto() - const override { - return nullptr; - } - - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::UNKNOWN_HIERARCHY; - } - - protected: - double data_sum = 0; - int ndata = 0; - - double shape, rate_alpha, rate_beta; -}; - -#endif // BAYESMIX_HIERARCHIES_GAMMAGAMMA_HIERARCHY_H_ diff --git a/examples/gamma_hierarchy/gamma_likelihood.h b/examples/gamma_hierarchy/gamma_likelihood.h new file mode 100644 index 000000000..25b3103c5 --- /dev/null +++ b/examples/gamma_hierarchy/gamma_likelihood.h @@ -0,0 +1,82 @@ +#ifndef BAYESMIX_HIERARCHIES_GAMMA_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_GAMMA_LIKELIHOOD_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "src/hierarchies/likelihoods/base_likelihood.h" +#include "src/hierarchies/likelihoods/states/base_state.h" + +namespace State { +class Gamma : public BaseState { + public: + double shape, rate; + using ProtoState = bayesmix::AlgorithmState::ClusterState; + + 
ProtoState get_as_proto() const override { + ProtoState out; + out.mutable_general_state()->set_size(2); + out.mutable_general_state()->mutable_data()->Add(shape); + out.mutable_general_state()->mutable_data()->Add(rate); + return out; + } + + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } + shape = state_.general_state().data()[0]; + rate = state_.general_state().data()[1]; + } +}; +} // namespace State + +class GammaLikelihood : public BaseLikelihood { + public: + GammaLikelihood() = default; + ~GammaLikelihood() = default; + bool is_multivariate() const override { return false; }; + bool is_dependent() const override { return false; }; + void clear_summary_statistics() override; + + // Getters and Setters + int get_ndata() const { return ndata; }; + double get_shape() const { return state.shape; }; + double get_data_sum() const { return data_sum; }; + + protected: + double compute_lpdf(const Eigen::RowVectorXd &datum) const override; + void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; + + //! Sum of data in the cluster + double data_sum = 0; + //! 
number of data in the cluster + int ndata = 0; +}; + +/* DEFINITIONS */ +void GammaLikelihood::clear_summary_statistics() { + data_sum = 0; + ndata = 0; +} + +double GammaLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { + return stan::math::gamma_lpdf(datum(0), state.shape, state.rate); +} + +void GammaLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, + bool add) { + if (add) { + data_sum += datum(0); + ndata += 1; + } else { + data_sum -= datum(0); + ndata -= 1; + } +} + +#endif // BAYESMIX_HIERARCHIES_GAMMA_LIKELIHOOD_H_ diff --git a/examples/gamma_hierarchy/gamma_prior_model.h b/examples/gamma_hierarchy/gamma_prior_model.h new file mode 100644 index 000000000..a49bafadc --- /dev/null +++ b/examples/gamma_hierarchy/gamma_prior_model.h @@ -0,0 +1,107 @@ +#ifndef BAYESMIX_HIERARCHIES_GAMMA_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_GAMMA_PRIOR_MODEL_H_ + +#include +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "gamma_likelihood.h" +#include "hierarchy_prior.pb.h" +#include "src/hierarchies/priors/base_prior_model.h" +#include "src/utils/rng.h" + +namespace Hyperparams { +struct Gamma { + double rate_alpha, rate_beta; +}; +} // namespace Hyperparams + +class GammaPriorModel + : public BasePriorModel { + public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + + GammaPriorModel(double shape_ = -1, double rate_alpha_ = -1, + double rate_beta_ = -1); + ~GammaPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + State::Gamma sample(ProtoHypersPtr hier_hypers = nullptr) override; + + void update_hypers(const std::vector + &states) override { + return; + }; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + ProtoHypersPtr get_hypers_proto() const override; + double get_shape() const { return shape; }; + + protected: + double shape, rate_alpha, rate_beta; + void initialize_hypers() override; +}; + 
+/* DEFINITIONS */ +GammaPriorModel::GammaPriorModel(double shape_, double rate_alpha_, + double rate_beta_) + : shape(shape_), rate_alpha(rate_alpha_), rate_beta(rate_beta_) { + create_empty_prior(); +}; + +double GammaPriorModel::lpdf(const google::protobuf::Message &state_) { + double rate = downcast_state(state_).general_state().data()[1]; + return stan::math::gamma_lpdf(rate, hypers->rate_alpha, hypers->rate_beta); +} + +State::Gamma GammaPriorModel::sample(ProtoHypersPtr hier_hypers) { + auto &rng = bayesmix::Rng::Instance().get(); + State::Gamma out; + + auto params = (hier_hypers) ? hier_hypers->general_state() + : get_hypers_proto()->general_state(); + double rate_alpha = params.data()[0]; + double rate_beta = params.data()[1]; + out.shape = shape; + out.rate = stan::math::gamma_rng(rate_alpha, rate_beta, rng); + return out; +} + +void GammaPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).general_state(); + hypers->rate_alpha = hyperscast.data()[0]; + hypers->rate_beta = hyperscast.data()[1]; +}; + +GammaPriorModel::ProtoHypersPtr GammaPriorModel::get_hypers_proto() const { + ProtoHypersPtr out = std::make_shared(); + out->mutable_general_state()->mutable_data()->Add(hypers->rate_alpha); + out->mutable_general_state()->mutable_data()->Add(hypers->rate_beta); + return out; +}; + +void GammaPriorModel::initialize_hypers() { + hypers->rate_alpha = rate_alpha; + hypers->rate_beta = rate_beta; + + // Checks + if (shape <= 0) { + throw std::runtime_error("shape must be positive"); + } + if (rate_alpha <= 0) { + throw std::runtime_error("rate_alpha must be positive"); + } + if (rate_beta <= 0) { + throw std::runtime_error("rate_beta must be positive"); + } +} + +#endif // BAYESMIX_HIERARCHIES_GAMMA_PRIOR_MODEL_H_ diff --git a/examples/gamma_hierarchy/gammagamma_hierarchy.h b/examples/gamma_hierarchy/gammagamma_hierarchy.h new file mode 100644 index 000000000..6fe135533 --- /dev/null +++ 
b/examples/gamma_hierarchy/gammagamma_hierarchy.h @@ -0,0 +1,47 @@ +#ifndef BAYESMIX_HIERARCHIES_GAMMA_GAMMA_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_GAMMA_GAMMA_HIERARCHY_H_ + +#include "gamma_likelihood.h" +#include "gamma_prior_model.h" +#include "gammagamma_updater.h" +#include "hierarchy_id.pb.h" +#include "src/hierarchies/base_hierarchy.h" + +class GammaGammaHierarchy + : public BaseHierarchy { + public: + GammaGammaHierarchy(double shape_, double rate_alpha_, double rate_beta_) { + auto prior = + std::make_shared(shape_, rate_alpha_, rate_beta_); + set_prior(prior); + }; + ~GammaGammaHierarchy() = default; + + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::UNKNOWN_HIERARCHY; + } + + void set_default_updater() { + updater = std::make_shared(); + } + + void initialize_state() override { + // Get hypers + auto hypers = prior->get_hypers(); + // Initialize likelihood state + State::Gamma state; + state.shape = prior->get_shape(); + state.rate = hypers.rate_alpha / hypers.rate_beta; + like->set_state(state); + }; + + double marg_lpdf(ProtoHypersPtr hier_params, + const Eigen::RowVectorXd &datum) const override { + throw( + std::runtime_error("marg_lpdf() not implemented for this hierarchy")); + return 0; + } +}; + +#endif // BAYESMIX_HIERARCHIES_GAMMA_GAMMA_HIERARCHY_H_ diff --git a/examples/gamma_hierarchy/gammagamma_updater.h b/examples/gamma_hierarchy/gammagamma_updater.h new file mode 100644 index 000000000..dd05b6ffd --- /dev/null +++ b/examples/gamma_hierarchy/gammagamma_updater.h @@ -0,0 +1,49 @@ +#ifndef BAYESMIX_HIERARCHIES_GAMMA_GAMMA_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_GAMMA_GAMMA_UPDATER_H_ + +#include "gamma_likelihood.h" +#include "gamma_prior_model.h" +#include "src/hierarchies/updaters/semi_conjugate_updater.h" + +class GammaGammaUpdater + : public SemiConjugateUpdater { + public: + GammaGammaUpdater() = default; + ~GammaGammaUpdater() = default; + + bool is_conjugate() const override { return true; }; + + 
ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) override; +}; + +/* DEFINITIONS */ +AbstractUpdater::ProtoHypersPtr GammaGammaUpdater::compute_posterior_hypers( + AbstractLikelihood& like, AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Getting required quantities from likelihood and prior + int card = likecast.get_card(); + double data_sum = likecast.get_data_sum(); + double ndata = likecast.get_ndata(); + double shape = priorcast.get_shape(); + auto hypers = priorcast.get_hypers(); + + // No update possible + if (card == 0) { + return priorcast.get_hypers_proto(); + } + // Compute posterior hyperparameters + double rate_alpha_new = hypers.rate_alpha + shape * ndata; + double rate_beta_new = hypers.rate_beta + data_sum; + + // Proto conversion + ProtoHypers out; + out.mutable_general_state()->mutable_data()->Add(rate_alpha_new); + out.mutable_general_state()->mutable_data()->Add(rate_beta_new); + return std::make_shared(out); +} + +#endif // BAYESMIX_HIERARCHIES_GAMMA_GAMMA_UPDATER_H_ diff --git a/examples/gamma_hierarchy/run_gamma_gamma.cc b/examples/gamma_hierarchy/run_gamma_gamma.cc index f6502b72d..d8e5cf0f5 100644 --- a/examples/gamma_hierarchy/run_gamma_gamma.cc +++ b/examples/gamma_hierarchy/run_gamma_gamma.cc @@ -1,6 +1,6 @@ #include -#include "gamma_gamma_hier.h" +#include "gammagamma_hierarchy.h" #include "src/includes.h" Eigen::MatrixXd simulate_data(const unsigned int ndata) { diff --git a/executables/run_mcmc.cc b/executables/run_mcmc.cc index eccc423c3..8a47aabb8 100644 --- a/executables/run_mcmc.cc +++ b/executables/run_mcmc.cc @@ -167,6 +167,7 @@ int main(int argc, char *argv[]) { mixing->get_mutable_prior()); bayesmix::read_proto_from_file(args.get("--hier-args"), hier->get_mutable_prior()); + hier->initialize(); // Read data matrices Eigen::MatrixXd data = diff --git 
a/python/notebooks/gaussian_mix_NNxIG.ipynb b/python/notebooks/gaussian_mix_NNxIG.ipynb new file mode 100644 index 000000000..96e3fa370 --- /dev/null +++ b/python/notebooks/gaussian_mix_NNxIG.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "6c73fa6a", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "from bayesmixpy import run_mcmc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79825dfc", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"BAYESMIX_EXE\"] = \"../../build/run_mcmc\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64a83071", + "metadata": {}, + "outputs": [], + "source": [ + "# Generate data\n", + "data = np.concatenate([\n", + " np.random.normal(loc=3, scale=1, size=100),\n", + " np.random.normal(loc=-3, scale=1, size=100),\n", + "])\n", + "\n", + "# Plot data\n", + "plt.hist(data)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13df394d", + "metadata": {}, + "outputs": [], + "source": [ + "# Hierarchy settings\n", + "hier_params = \\\n", + "\"\"\"\n", + "fixed_values {\n", + " mean: 0.0\n", + " var: 10.0\n", + " shape: 2.0\n", + " scale: 2.0\n", + "}\n", + "\"\"\"\n", + "\n", + "# Mixing settings\n", + "mix_params = \\\n", + "\"\"\"\n", + "fixed_value {\n", + " totalmass: 1.0\n", + "}\n", + "\"\"\"\n", + "\n", + "# Algorithm settings\n", + "algo_params = \\\n", + "\"\"\"\n", + "algo_id: \"Neal8\"\n", + "rng_seed: 20201124\n", + "iterations: 2000\n", + "burnin: 1000\n", + "init_num_clusters: 3\n", + "neal8_n_aux: 3\n", + "\"\"\"\n", + "\n", + "# Evaluation grid\n", + "dens_grid = np.linspace(-6.5, 6.5, 1000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12505b6e", + "metadata": {}, + "outputs": [], + "source": [ + "# Fit model using bayesmixpy\n", + "eval_dens, n_clus, 
clus_chain, best_clus = run_mcmc(\"NNxIG\",\"DP\", data,\n", + " hier_params, mix_params, algo_params,\n", + " dens_grid, return_num_clusters=True,\n", + " return_clusters=True, return_best_clus=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8470fc0a", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n", + "\n", + "# Posterior distribution of clusters\n", + "x, y = np.unique(n_clus, return_counts=True)\n", + "axes[0].bar(x, y / y.sum())\n", + "axes[0].set_xticks(x)\n", + "axes[0].set_title(\"Posterior distribution of the number of clusters\")\n", + "\n", + "# Plot mean posterior density\n", + "axes[1].plot(dens_grid, np.exp(np.mean(eval_dens, axis=0)))\n", + "axes[1].hist(data, alpha=0.3, density=True)\n", + "for c in np.unique(best_clus):\n", + " data_in_clus = data[best_clus == c]\n", + " axes[1].scatter(data_in_clus, np.zeros_like(data_in_clus) + 0.01)\n", + "axes[1].set_title(\"Posterior density estimate\")\n", + "\n", + "# Show results\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/python/notebooks/gaussian_mix_uni.ipynb b/python/notebooks/gaussian_mix_uni.ipynb index 2c6f7f4f0..7a9c71ed1 100644 --- a/python/notebooks/gaussian_mix_uni.ipynb +++ b/python/notebooks/gaussian_mix_uni.ipynb @@ -210,24 +210,24 @@ "neal2_algo = \"\"\"\n", "algo_id: \"Neal2\"\n", "rng_seed: 20201124\n", - "iterations: 10\n", - "burnin: 5\n", + "iterations: 100\n", + "burnin: 50\n", "init_num_clusters: 3\n", "\"\"\"\n", "\n", "neal3_algo = \"\"\"\n", "algo_id: \"Neal3\"\n", "rng_seed: 
20201124\n", - "iterations: 10\n", - "burnin: 5\n", + "iterations: 100\n", + "burnin: 50\n", "init_num_clusters: 3\n", "\"\"\"\n", "\n", "neal8_algo = \"\"\"\n", "algo_id: \"Neal8\"\n", "rng_seed: 20201124\n", - "iterations: 10\n", - "burnin: 5\n", + "iterations: 100\n", + "burnin: 50\n", "init_num_clusters: 3\n", "neal8_n_aux: 3\n", "\"\"\"\n", @@ -244,7 +244,7 @@ "\n", "`return_clusters=False, return_num_clusters=False, return_best_clus=False`\n", "\n", - "Observe that the number of iterations is extremely small! In real problems, you might want to set the burnin at least to 1000 iterations and the total number of iterations to at leas 2000." + "Observe that the number of iterations is extremely small! In real problems, you might want to set the burnin at least to 1000 iterations and the total number of iterations to at least 2000." ] }, { @@ -449,13 +449,6 @@ "source": [ "np.var(data)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -463,7 +456,7 @@ "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -477,7 +470,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.10" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/src/algorithms/base_algorithm.cc b/src/algorithms/base_algorithm.cc index 19597a406..bfe9ac76c 100644 --- a/src/algorithms/base_algorithm.cc +++ b/src/algorithms/base_algorithm.cc @@ -1,7 +1,7 @@ #include "base_algorithm.h" -#include #include +#include #include #include "algorithm_params.pb.h" diff --git a/src/algorithms/base_algorithm.h b/src/algorithms/base_algorithm.h index 988260f5e..b42df0611 100644 --- a/src/algorithms/base_algorithm.h +++ b/src/algorithms/base_algorithm.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include 
"algorithm_id.pb.h" diff --git a/src/algorithms/conditional_algorithm.cc b/src/algorithms/conditional_algorithm.cc index 30090013a..aebc2cdf5 100644 --- a/src/algorithms/conditional_algorithm.cc +++ b/src/algorithms/conditional_algorithm.cc @@ -1,7 +1,7 @@ #include "conditional_algorithm.h" -#include #include +#include #include "algorithm_state.pb.h" #include "base_algorithm.h" diff --git a/src/algorithms/conditional_algorithm.h b/src/algorithms/conditional_algorithm.h index f99b01b1e..f7bed0391 100644 --- a/src/algorithms/conditional_algorithm.h +++ b/src/algorithms/conditional_algorithm.h @@ -1,34 +1,40 @@ #ifndef BAYESMIX_ALGORITHMS_CONDITIONAL_ALGORITHM_H_ #define BAYESMIX_ALGORITHMS_CONDITIONAL_ALGORITHM_H_ -#include #include +#include #include "base_algorithm.h" #include "src/collectors/base_collector.h" -//! Template class for a conditional sampler deriving from `BaseAlgorithm`. - -//! This template class implements a generic Gibbs sampling conditional -//! algorithm as the child of the `BaseAlgorithm` class. -//! A mixture model sampled from a conditional algorithm can be expressed as -//! x_i | c_i, phi_1, ..., phi_k ~ f(x_i|phi_(c_i)) (data likelihood); -//! phi_1, ... phi_k ~ G (unique values); -//! c_1, ... c_n | w_1, ..., w_k ~ Cat(w_1, ... w_k) (cluster allocations); -//! w_1, ..., w_k ~ p(w_1, ..., w_k) (mixture weights) -//! where f(x | phi_j) is a density for each value of phi_j, the c_i take -//! values in {1, ..., k} and w_1, ..., w_k are nonnegative weights whose sum -//! is a.s. 1, i.e. p(w_1, ... w_k) is a probability distribution on the k-1 -//! dimensional unit simplex). -//! In this library, each phi_j is represented as an `Hierarchy` object (which -//! inherits from `AbstractHierarchy`), that also holds the information related -//! to the base measure `G` is (see `AbstractHierarchy`). -//! The weights (w_1, ..., w_k) are represented as a `Mixing` object, which -//! inherits from `AbstractMixing`. - -//! 
The state of a conditional algorithm consists of the unique values, the -//! cluster allocations and the mixture weights. The former two are stored in -//! this class, while the weights are stored in the `Mixing` object. +/** + * Template class for a conditional sampler deriving from `BaseAlgorithm`. + * + * This template class implements a generic Gibbs sampling conditional + * algorithm as the child of the `BaseAlgorithm` class. + * A mixture model sampled from a conditional algorithm can be expressed as + * + * \f[ + * x_i \mid c_i, \theta_1, \dots, \theta_k &\sim f(x_i \mid \theta_{c_i}) \\ + * \theta_1, \dots, \theta_k &\sim G_0 \\ + * c_1, \dots, c_n \mid w_1, \dots, w_k &\sim \text{Cat}(w_1, \dots, w_k) \\ + * w_1, \dots, w_k &\sim p(w_1, \dots, w_k) + * \f] + * + * where \f$ f(x \mid \theta_j) \f$ is a density for each value of \f$ \theta_j + * \f$, \f$ c_i \f$ take values in \f$ \{1, \dots, k\} \f$ and \f$ w_1, \dots, + * w_k \f$ are nonnegative weights whose sum is a.s. 1, i.e. \f$ p(w_1, \dots, + * w_k) \f$ is a probability distribution on the k-1 dimensional unit simplex). + * In this library, each \f$ \theta_j \f$ is represented as an `Hierarchy` + * object (which inherits from `AbstractHierarchy`), that also holds the + * information related to the base measure \f$ G \f$ is (see + * `AbstractHierarchy`). The weights \f$ (w_1, \dots, w_k) \f$ are represented + * as a `Mixing` object, which inherits from `AbstractMixing`. + * + * The state of a conditional algorithm consists of the unique values, the + * cluster allocations and the mixture weights. The former two are stored in + * this class, while the weights are stored in the `Mixing` object. 
+ */ class ConditionalAlgorithm : public BaseAlgorithm { public: diff --git a/src/algorithms/marginal_algorithm.cc b/src/algorithms/marginal_algorithm.cc index 279402c1a..ddf4340de 100644 --- a/src/algorithms/marginal_algorithm.cc +++ b/src/algorithms/marginal_algorithm.cc @@ -1,8 +1,8 @@ #include "marginal_algorithm.h" -#include #include #include +#include #include "algorithm_state.pb.h" #include "base_algorithm.h" diff --git a/src/algorithms/marginal_algorithm.h b/src/algorithms/marginal_algorithm.h index 44fd4164c..95fcddd92 100644 --- a/src/algorithms/marginal_algorithm.h +++ b/src/algorithms/marginal_algorithm.h @@ -1,36 +1,44 @@ #ifndef BAYESMIX_ALGORITHMS_MARGINAL_ALGORITHM_H_ #define BAYESMIX_ALGORITHMS_MARGINAL_ALGORITHM_H_ -#include #include +#include #include "base_algorithm.h" #include "src/collectors/base_collector.h" #include "src/hierarchies/abstract_hierarchy.h" -//! Template class for a marginal sampler deriving from `BaseAlgorithm`. - -//! This template class implements a generic Gibbs sampling marginal algorithm -//! as the child of the `BaseAlgorithm` class. -//! A mixture model sampled from a Marginal Algorithm can be expressed as -//! x_i | c_i, phi_1, ..., phi_k ~ f(x_i|phi_(c_i)) (data likelihood); -//! phi_1, ... phi_k ~ G (unique values); -//! c_1, ... c_n ~ EPPF(c_1, ... c_n) (cluster allocations); -//! where f(x | phi_j) is a density for each value of phi_j and the c_i take -//! values in {1, ..., k}. -//! Depending on the actual implementation, the algorithm might require -//! the kernel/likelihood f(x | phi) and G(phi) to be conjugagte or not. -//! In the former case, a `ConjugateHierarchy` must be specified. -//! In this library, each phi_j is represented as an `Hierarchy` object (which -//! inherits from `AbstractHierarchy`), that also holds the information related -//! to the base measure `G` is (see `AbstractHierarchy`). The EPPF is instead -//! represented as a `Mixing` object, which inherits from `AbstractMixing`. -//! -//! 
The state of a marginal algorithm only consists of allocations and unique -//! values. In this class of algorithms, the local lpdf estimate for a single -//! iteration is a weighted average of likelihood values corresponding to each -//! component (i.e. cluster), with the weights being based on its cardinality, -//! and of the marginal component, which depends on the specific algorithm. +/** + * Template class for a marginal sampler deriving from `BaseAlgorithm`. + * + * This template class implements a generic Gibbs sampling marginal algorithm + * as the child of the `BaseAlgorithm` class. + * A mixture model sampled from a Marginal Algorithm can be expressed as + * + * \f[ + * x_i \mid c_i, \theta_1, \dots, \theta_k &\sim f(x_i \mid \theta_{c_i}) + * \\ + * \theta_1, \dots, \theta_k &\sim G_0 \\ + * c_1, \dots, c_n &\sim EPPF(c_1, \dots, c_n) + * \f] + * + * where \f$ f(x \mid \theta_j) \f$ is a density for each value of \f$ \theta_j + * \f$ and \f$ c_i \f$ take values in \f$ {1, \dots, k} \f$. Depending on the + * actual implementation, the algorithm might require the kernel/likelihood \f$ + * f(x \mid \theta) \f$ and \f$ G_0(\phi) \f$ to be conjugagte or not. In the + * former case, a conjugate hierarchy must be specified. In this library, each + * \f$ \theta_j \f$ is represented as an `Hierarchy` object (which inherits + * from `AbstractHierarchy`), that also holds the information related to the + * base measure \f$ G_0 \f$ is (see `AbstractHierarchy`). The \f$ EPPF \f$ is + * instead represented as a `Mixing` object, which inherits from + * `AbstractMixing`. + * + * The state of a marginal algorithm only consists of allocations and unique + * values. In this class of algorithms, the local lpdf estimate for a single + * iteration is a weighted average of likelihood values corresponding to each + * component (i.e. cluster), with the weights being based on its cardinality, + * and of the marginal component, which depends on the specific algorithm. 
+ */ class MarginalAlgorithm : public BaseAlgorithm { public: @@ -45,9 +53,12 @@ class MarginalAlgorithm : public BaseAlgorithm { protected: //! Computes marginal contribution of the given cluster to the lpdf estimate - //! @param hier Pointer to the `Hierarchy` object representing the cluster - //! @param grid Grid of row points on which the density is to be evaluated - //! @return The marginal component of the estimate + //! @param hier Pointer to the `Hierarchy` object representing the + //! cluster + //! @param grid Grid of row points on which the density is to be + //! evaluated + //! @param covariate (Optional) covariate vectors associated to data + //! @return The marginal component of the estimate virtual Eigen::VectorXd lpdf_marginal_component( const std::shared_ptr hier, const Eigen::MatrixXd &grid, diff --git a/src/algorithms/neal2_algorithm.cc b/src/algorithms/neal2_algorithm.cc index 2973cc8dc..d67550c07 100644 --- a/src/algorithms/neal2_algorithm.cc +++ b/src/algorithms/neal2_algorithm.cc @@ -1,8 +1,8 @@ #include "neal2_algorithm.h" -#include #include #include +#include #include #include "algorithm_id.pb.h" diff --git a/src/algorithms/neal2_algorithm.h b/src/algorithms/neal2_algorithm.h index d9f63dc6a..a51e8504d 100644 --- a/src/algorithms/neal2_algorithm.h +++ b/src/algorithms/neal2_algorithm.h @@ -1,8 +1,8 @@ #ifndef BAYESMIX_ALGORITHMS_NEAL2_ALGORITHM_H_ #define BAYESMIX_ALGORITHMS_NEAL2_ALGORITHM_H_ -#include #include +#include #include "algorithm_id.pb.h" #include "marginal_algorithm.h" diff --git a/src/algorithms/neal3_algorithm.cc b/src/algorithms/neal3_algorithm.cc index 7e209b81f..e7bf9712a 100644 --- a/src/algorithms/neal3_algorithm.cc +++ b/src/algorithms/neal3_algorithm.cc @@ -1,7 +1,7 @@ #include "neal3_algorithm.h" -#include #include +#include #include "hierarchy_id.pb.h" #include "mixing_id.pb.h" diff --git a/src/algorithms/neal8_algorithm.cc b/src/algorithms/neal8_algorithm.cc index c38bc803c..6ee6d25fd 100644 --- 
a/src/algorithms/neal8_algorithm.cc +++ b/src/algorithms/neal8_algorithm.cc @@ -1,8 +1,8 @@ #include "neal8_algorithm.h" -#include #include #include +#include #include "algorithm_id.pb.h" #include "algorithm_state.pb.h" diff --git a/src/algorithms/neal8_algorithm.h b/src/algorithms/neal8_algorithm.h index 5ff026809..edbdb2e4a 100644 --- a/src/algorithms/neal8_algorithm.h +++ b/src/algorithms/neal8_algorithm.h @@ -1,8 +1,8 @@ #ifndef BAYESMIX_ALGORITHMS_NEAL8_ALGORITHM_H_ #define BAYESMIX_ALGORITHMS_NEAL8_ALGORITHM_H_ -#include #include +#include #include #include "algorithm_id.pb.h" diff --git a/src/algorithms/semihdp_sampler.h b/src/algorithms/semihdp_sampler.h index e684c6dd1..3558a3ae9 100644 --- a/src/algorithms/semihdp_sampler.h +++ b/src/algorithms/semihdp_sampler.h @@ -1,9 +1,9 @@ #ifndef SRC_ALGORITHMS_SEMIHDP_SAMPLER_H #define SRC_ALGORITHMS_SEMIHDP_SAMPLER_H -#include #include #include +#include #include #include diff --git a/src/algorithms/split_and_merge_algorithm.h b/src/algorithms/split_and_merge_algorithm.h index 43d595c53..61d5c0d7d 100644 --- a/src/algorithms/split_and_merge_algorithm.h +++ b/src/algorithms/split_and_merge_algorithm.h @@ -1,8 +1,8 @@ #ifndef BAYESMIX_ALGORITHMS_SPLIT_AND_MERGE_ALGORITHM_H_ #define BAYESMIX_ALGORITHMS_SPLIT_AND_MERGE_ALGORITHM_H_ -#include #include +#include #include "algorithm_id.pb.h" #include "marginal_algorithm.h" diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index 0207bd04f..42a45fd87 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -2,15 +2,14 @@ target_sources(bayesmix PUBLIC abstract_hierarchy.h base_hierarchy.h - conjugate_hierarchy.h - lin_reg_uni_hierarchy.h - lin_reg_uni_hierarchy.cc nnig_hierarchy.h - nnig_hierarchy.cc + nnxig_hierarchy.h nnw_hierarchy.h - nnw_hierarchy.cc + lin_reg_uni_hierarchy.h fa_hierarchy.h - fa_hierarchy.cc lapnig_hierarchy.h - lapnig_hierarchy.cc ) + +add_subdirectory(likelihoods) +add_subdirectory(priors) 
+add_subdirectory(updaters) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index b69716a0c..947f362cf 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -3,64 +3,88 @@ #include -#include #include #include #include #include +#include #include "algorithm_state.pb.h" #include "hierarchy_id.pb.h" +#include "src/hierarchies/likelihoods/abstract_likelihood.h" +#include "src/hierarchies/priors/abstract_prior_model.h" +#include "src/hierarchies/updaters/abstract_updater.h" #include "src/utils/rng.h" -//! Abstract base class for a hierarchy object. -//! This class is the basis for a curiously recurring template pattern (CRTP) -//! for `Hierarchy` objects, and is solely composed of interface functions for -//! derived classes to use. For more information about this pattern, as well -//! the list of methods required for classes in this inheritance tree, please -//! refer to the README.md file included in this folder. - -//! This abstract class represents a Bayesian hierarchical model: -//! x_1, ..., x_n \sim f(x | \theta) -//! theta \sim G -//! A Hierarchy object can compute the following quantities: -//! 1- the likelihood log-probability density function -//! 2- the prior predictive probability: \int_\Theta f(x | theta) G(d\theta) -//! (for conjugate models only) -//! 3- the posterior predictive probability -//! \int_\Theta f(x | theta) G(d\theta | x_1, ..., x_n) -//! (for conjugate models only) -//! Moreover, the Hierarchy knows how to sample from the full conditional of -//! theta, possibly in an approximate way. -//! -//! In the context of our Gibbs samplers, an hierarchy represents the parameter -//! value associated to a certain cluster, and also knows which observations -//! are allocated to that cluster. -//! Moreover, hyperparameters and (possibly) hyperpriors associated to them can -//! be shared across multiple Hierarchies objects via a shared pointer. -//! 
In conjunction with a single `Mixing` object, a collection of `Hierarchy` -//! objects completely defines a mixture model, and these two parts can be -//! chosen independently of each other. -//! Communication with other classes, as well as storage of some relevant -//! values, is performed via appropriately defined Protobuf messages (see for -//! instance the proto/ls_state.proto and proto/hierarchy_prior.proto files) -//! and their relative class methods. +/** + * Abstract base class for a hierarchy object. + * This class is the basis for a curiously recurring template pattern (CRTP) + * for `Hierarchy` objects, and is solely composed of interface functions for + * derived classes to use. For more information about this pattern, as well + * the list of methods required for classes in this inheritance tree, please + * refer to the README.md file included in this folder. + * + * This abstract class represents a Bayesian hierarchical model: + * + * \f[ + * x_1,\dots,x_n &\sim f(x \mid \theta) \\ + * \theta &\sim G + * \f] + * + * A Hierarchy object can compute the following quantities: + * + * 1. the likelihood log-probability density function + * 2. the prior predictive probability: \f$ \int_\Theta f(x \mid \theta) + * G(d\theta) \f$ (for conjugate models only) + * 3. the posterior predictive probability + * \f$ \int_\Theta f(x \mid \theta) G(d\theta \mid x_1, ..., x_n) \f$ + * (for conjugate models only) + * + * Moreover, the Hierarchy knows how to sample from the full conditional of + * \f$ \theta \f$, possibly in an approximate way. + * + * In the context of our Gibbs samplers, an hierarchy represents the parameter + * value associated to a certain cluster, and also knows which observations + * are allocated to that cluster. + * + * Moreover, hyperparameters and (possibly) hyperpriors associated to them can + * be shared across multiple Hierarchies objects via a shared pointer. 
+ * In conjunction with a single `Mixing` object, a collection of `Hierarchy` + * objects completely defines a mixture model, and these two parts can be + * chosen independently of each other. + * + * Communication with other classes, as well as storage of some relevant + * values, is performed via appropriately defined Protobuf messages (see for + * instance the `proto/ls_state.proto` and `proto/hierarchy_prior.proto` files) + * and their relative class methods. + */ class AbstractHierarchy { public: + //! Set the update algorithm for the current hierarchy + virtual void set_updater(std::shared_ptr updater_) = 0; + + //! Returns (a pointer to) the likelihood for the current hierarchy + virtual std::shared_ptr get_likelihood() = 0; + + //! Returns (a pointer to) the prior model for the current hierarchy + virtual std::shared_ptr get_prior() = 0; + + //! Default destructor virtual ~AbstractHierarchy() = default; //! Returns an independent, data-less copy of this object virtual std::shared_ptr clone() const = 0; + //! Returns an independent, data-less copy of this object virtual std::shared_ptr deep_clone() const = 0; // EVALUATION FUNCTIONS FOR SINGLE POINTS //! 
Public wrapper for `like_lpdf()` methods - double get_like_lpdf( + virtual double get_like_lpdf( const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { - if (is_dependent()) { + if (is_dependent() and covariate.size() != 0) { return like_lpdf(datum, covariate); } else { return like_lpdf(datum); @@ -74,8 +98,13 @@ class AbstractHierarchy { virtual double prior_pred_lpdf( const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { - throw std::runtime_error( - "Cannot call prior_pred_lpdf() from a non-conjugate hierarchy"); + if (is_conjugate()) { + throw std::runtime_error( + "prior_pred_lpdf() not implemented for this hierarchy"); + } else { + throw std::runtime_error( + "Cannot call prior_pred_lpdf() from a non-conjugate hierarchy"); + } } //! Evaluates the log-conditional predictive distr. of data in a single point @@ -85,8 +114,14 @@ class AbstractHierarchy { virtual double conditional_pred_lpdf( const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { - throw std::runtime_error( - "Cannot call conditional_pred_lpdf() from a non-conjugate hierarchy"); + if (is_conjugate()) { + throw std::runtime_error( + "conditional_pred_lpdf() not implemented for this hierarchy"); + } else { + throw std::runtime_error( + "Cannot call conditional_pred_lpdf() from a non-conjugate " + "hierarchy"); + } } // EVALUATION FUNCTIONS FOR GRIDS OF POINTS @@ -105,8 +140,13 @@ class AbstractHierarchy { virtual Eigen::VectorXd prior_pred_lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const { - throw std::runtime_error( - "Cannot call prior_pred_lpdf_grid() from a non-conjugate hierarchy"); + if (is_conjugate()) { + throw std::runtime_error( + "prior_pred_lpdf_grid() not implemented for this hierarchy"); + } else { + throw std::runtime_error( + "Cannot call prior_pred_lpdf_grid() from a non-conjugate 
hierarchy"); + } } //! Evaluates the log-prior predictive distr. of data in a grid of points @@ -116,9 +156,14 @@ class AbstractHierarchy { virtual Eigen::VectorXd conditional_pred_lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const { - throw std::runtime_error( - "Cannot call conditional_pred_lpdf_grid() from a non-conjugate " - "hierarchy"); + if (is_conjugate()) { + throw std::runtime_error( + "conditional_pred_lpdf_grid() not implemented for this hierarchy"); + } else { + throw std::runtime_error( + "Cannot call conditional_pred_lpdf_grid() from a non-conjugate " + "hierarchy"); + } } // SAMPLING FUNCTIONS @@ -208,12 +253,12 @@ class AbstractHierarchy { virtual bool is_multivariate() const = 0; //! Returns whether the hierarchy depends on covariate values or not - virtual bool is_dependent() const { return false; } + virtual bool is_dependent() const = 0; //! Returns whether the hierarchy represents a conjugate model or not - virtual bool is_conjugate() const { return false; } + virtual bool is_conjugate() const = 0; - //! Main function that initializes members to appropriate values + //! 
Sets the (pointer to) the dataset in the cluster virtual void set_dataset(const Eigen::MatrixXd *const dataset) = 0; protected: @@ -227,7 +272,8 @@ class AbstractHierarchy { throw std::runtime_error( "Cannot call like_lpdf() from a non-dependent hierarchy"); } else { - throw std::runtime_error("like_lpdf() not implemented"); + throw std::runtime_error( + "like_lpdf() not implemented for this hierarchy"); } } @@ -239,7 +285,8 @@ class AbstractHierarchy { throw std::runtime_error( "Cannot call like_lpdf() from a dependent hierarchy"); } else { - throw std::runtime_error("like_lpdf() not implemented"); + throw std::runtime_error( + "like_lpdf() not implemented for this hierarchy"); } } @@ -255,7 +302,8 @@ class AbstractHierarchy { "Cannot call update_summary_statistics() from a non-dependent " "hierarchy"); } else { - throw std::runtime_error("update_summary_statistics() not implemented"); + throw std::runtime_error( + "update_summary_statistics() not implemented for this hierarchy"); } } @@ -269,7 +317,8 @@ class AbstractHierarchy { "Cannot call update_summary_statistics() from a dependent " "hierarchy"); } else { - throw std::runtime_error("update_summary_statistics() not implemented"); + throw std::runtime_error( + "update_summary_statistics() not implemented for this hierarchy"); } } }; diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index c4c4f8953..86ea08a0a 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -3,16 +3,17 @@ #include -#include #include #include #include -#include +#include #include "abstract_hierarchy.h" #include "algorithm_state.pb.h" #include "hierarchy_id.pb.h" +#include "src/utils/covariates_getter.h" #include "src/utils/rng.h" +#include "updaters/target_lpdf_unconstrained.h" //! Base template class for a hierarchy object. @@ -22,326 +23,358 @@ //! class for further information). It includes class members and some more //! 
functions which could not be implemented in the non-templatized abstract //! class. -//! See, for instance, `ConjugateHierarchy` and `NNIGHierarchy` to better -//! understand the CRTP patterns. +//! See, for instance, `NNIGHierarchy` to better understand the CRTP patterns. //! @tparam Derived Name of the implemented derived class -//! @tparam State Class name of the container for state values -//! @tparam Hyperparams Class name of the container for hyperprior parameters -//! @tparam Prior Class name of the container for prior parameters +//! @tparam Likelihood Class name of the likelihood model for the hierarchy +//! @tparam PriorModel Class name of the prior model for the hierarchy -template +template class BaseHierarchy : public AbstractHierarchy { + protected: + //! Container for the likelihood of the hierarchy + std::shared_ptr like = std::make_shared(); + + //! Container for the prior model of the hierarchy + std::shared_ptr prior = std::make_shared(); + + //! Container for the update algorithm + std::shared_ptr updater; + public: - BaseHierarchy() = default; + // Useful type aliases + using HyperParams = decltype(prior->get_hypers()); + using ProtoHypersPtr = AbstractUpdater::ProtoHypersPtr; + using ProtoHypers = AbstractUpdater::ProtoHypers; + + //! Constructor that allows the specification of Likelihood, PriorModel and + //! Updater for a given Hierarchy + BaseHierarchy(std::shared_ptr like_ = nullptr, + std::shared_ptr prior_ = nullptr, + std::shared_ptr updater_ = nullptr) { + if (like_) { + set_likelihood(like_); + } + if (prior_) { + set_prior(prior_); + } + if (updater_) { + set_updater(updater_); + } else { + static_cast(this)->set_default_updater(); + } + } + + //! Default destructor ~BaseHierarchy() = default; + //! Sets the likelihood for the current hierarchy + void set_likelihood(std::shared_ptr like_) /*override*/ { + like = std::static_pointer_cast(like_); + } + + //! 
Sets the prior model for the current hierarchy + void set_prior(std::shared_ptr prior_) /*override*/ { + prior = std::static_pointer_cast(prior_); + } + + //! Sets the update algorithm for the current hierarchy + void set_updater(std::shared_ptr updater_) override { + updater = updater_; + }; + + //! Returns (a pointer to) the likelihood for the current hierarchy + std::shared_ptr get_likelihood() override { + return like; + } + + //! Returns (a pointer to) the prior model for the current hierarchy. + std::shared_ptr get_prior() override { return prior; } + //! Returns an independent, data-less copy of this object std::shared_ptr clone() const override { + // Create copy of the hierarchy auto out = std::make_shared(static_cast(*this)); - out->clear_data(); - out->clear_summary_statistics(); + // Cloning each component class + out->set_likelihood(std::static_pointer_cast(like->clone())); + out->set_prior(std::static_pointer_cast(prior->clone())); return out; - } + }; - //! Returns an independent, data-less copy of this object + //! Returns an independent, data-less deep copy of this object std::shared_ptr deep_clone() const override { + // Create copy of the hierarchy auto out = std::make_shared(static_cast(*this)); + // Simple clone for Likelihood is enough + out->set_likelihood(std::static_pointer_cast(like->clone())); + // Deep clone required for PriorModel + out->set_prior(std::static_pointer_cast(prior->deep_clone())); + return out; + } - out->clear_data(); - out->clear_summary_statistics(); + //! Public wrapper for `like_lpdf()` methods + double get_like_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = + Eigen::RowVectorXd(0)) const override { + return like->lpdf(datum, covariate); + } - out->create_empty_prior(); - std::shared_ptr new_prior(prior->New()); - new_prior->CopyFrom(*prior.get()); - out->get_mutable_prior()->CopyFrom(*new_prior.get()); + //! Evaluates the log-likelihood of data in a grid of points + //! 
@param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf + Eigen::VectorXd like_lpdf_grid(const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = + Eigen::MatrixXd(0, 0)) const override { + return like->lpdf_grid(data, covariates); + }; + + //! Public wrapper for `marg_lpdf()` methods + double get_marg_lpdf( + ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const { + if (this->is_dependent() and covariate.size() != 0) { + return marg_lpdf(hier_params, datum, covariate); + } else { + return marg_lpdf(hier_params, datum); + } + } - out->create_empty_hypers(); - auto curr_hypers_proto = get_hypers_proto(); - out->set_hypers_from_proto(*curr_hypers_proto.get()); - out->initialize(); - return out; + //! Evaluates the log-prior predictive distribution of data in a single point + //! @param datum Point which is to be evaluated + //! @param covariate (Optional) covariate vector associated to datum + //! @return The evaluation of the lpdf + double prior_pred_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = + Eigen::RowVectorXd(0)) const override { + return get_marg_lpdf(prior->get_hypers_proto(), datum, covariate); } - //! Evaluates the log-likelihood of data in a grid of points + //! Evaluates the log-prior predictive distr. of data in a grid of points //! @param data Grid of points (by row) which are to be evaluated //! @param covariates (Optional) covariate vectors associated to data //! 
@return The evaluation of the lpdf - virtual Eigen::VectorXd like_lpdf_grid( + Eigen::VectorXd prior_pred_lpdf_grid( const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; + const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) + const override { + Eigen::VectorXd lpdf(data.rows()); + if (covariates.cols() == 0) { + // Pass null value as covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->prior_pred_lpdf( + data.row(i), Eigen::RowVectorXd(0)); + } + } else if (covariates.rows() == 1) { + // Use unique covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->prior_pred_lpdf( + data.row(i), covariates.row(0)); + } + } else { + // Use different covariates + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->prior_pred_lpdf( + data.row(i), covariates.row(i)); + } + } + return lpdf; + } - //! Generates new state values from the centering prior distribution - void sample_prior() override { - state = static_cast(this)->draw(*hypers); + //! Evaluates the log-conditional predictive distr. of data in a single point + //! @param datum Point which is to be evaluated + //! @param covariate (Optional) covariate vector associated to datum + //! @return The evaluation of the lpdf + double conditional_pred_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = + Eigen::RowVectorXd(0)) const override { + return get_marg_lpdf(updater->compute_posterior_hypers(*like, *prior), + datum, covariate); + } + + //! Evaluates the log-prior predictive distr. of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! 
@return The evaluation of the lpdf + Eigen::VectorXd conditional_pred_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) + const override { + Eigen::VectorXd lpdf(data.rows()); + covariates_getter cov_getter(covariates); + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->conditional_pred_lpdf( + data.row(i), cov_getter(i)); + } + return lpdf; } + //! Generates new state values from the centering prior distribution + void sample_prior() override { like->set_state(prior->sample(), false); }; + + //! Generates new state values from the centering posterior distribution + //! @param update_params Save posterior hypers after the computation? + void sample_full_cond(bool update_params = false) override { + updater->draw(*like, *prior, update_params); + }; + //! Overloaded version of sample_full_cond(bool), mainly used for debugging - virtual void sample_full_cond( + void sample_full_cond( const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override; + const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override { + like->clear_data(); + like->clear_summary_statistics(); + + covariates_getter cov_getter(covariates); + for (int i = 0; i < data.rows(); i++) { + static_cast(this)->add_datum(i, data.row(i), false, + cov_getter(i)); + } + static_cast(this)->sample_full_cond(true); + }; + + //! Updates hyperparameter values given a vector of cluster states + void update_hypers(const std::vector + &states) override { + prior->update_hypers(states); + }; + + //! Returns the class of the current state + auto get_state() const -> decltype(like->get_state()) { + return like->get_state(); + }; //! Returns the current cardinality of the cluster - int get_card() const override { return card; } + int get_card() const override { return like->get_card(); }; //! 
Returns the logarithm of the current cardinality of the cluster - double get_log_card() const override { return log_card; } + double get_log_card() const override { return like->get_log_card(); }; //! Returns the indexes of data points belonging to this cluster - std::set get_data_idx() const override { return cluster_data_idx; } + std::set get_data_idx() const override { return like->get_data_idx(); }; + + //! Writes current state to a Protobuf message and return a shared_ptr + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::ClusterState message by adding the appropriate type + std::shared_ptr get_state_proto() + const override { + return std::make_shared( + like->get_state().get_as_proto()); + } //! Returns a pointer to the Protobuf message of the prior of this cluster google::protobuf::Message *get_mutable_prior() override { - if (prior == nullptr) { - create_empty_prior(); - } - return prior.get(); - } + return prior->get_mutable_prior(); + }; //! Writes current state to a Protobuf message by pointer - void write_state_to_proto( - google::protobuf::Message *const out) const override; + void write_state_to_proto(google::protobuf::Message *out) const override { + like->write_state_to_proto(out); + }; //! Writes current values of the hyperparameters to a Protobuf message by //! pointer - void write_hypers_to_proto( - google::protobuf::Message *const out) const override; + void write_hypers_to_proto(google::protobuf::Message *out) const override { + prior->write_hypers_to_proto(out); + }; - //! Returns the struct of the current state - State get_state() const { return state; } + //! Read and set state values from a given Protobuf message + void set_state_from_proto(const google::protobuf::Message &state_) override { + like->set_state_from_proto(state_); + }; - //! Returns the struct of the current prior hyperparameters - Hyperparams get_hypers() const { return *hypers; } - - //! 
Returns the struct of the current posterior hyperparameters - Hyperparams get_posterior_hypers() const { return posterior_hypers; } + //! Read and set hyperparameter values from a given Protobuf message + void set_hypers_from_proto( + const google::protobuf::Message &state_) override { + prior->set_hypers_from_proto(state_); + }; //! Adds a datum and its index to the hierarchy void add_datum( const int id, const Eigen::RowVectorXd &datum, const bool update_params = false, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override { + like->add_datum(id, datum, covariate); + if (update_params) { + updater->save_posterior_hypers( + updater->compute_posterior_hypers(*like, *prior)); + } + }; //! Removes a datum and its index from the hierarchy void remove_datum( const int id, const Eigen::RowVectorXd &datum, const bool update_params = false, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override { + like->remove_datum(id, datum, covariate); + if (update_params) { + updater->save_posterior_hypers( + updater->compute_posterior_hypers(*like, *prior)); + } + }; //! Main function that initializes members to appropriate values void initialize() override { - hypers = std::make_shared(); - check_prior_is_set(); - initialize_hypers(); + prior->initialize(); + if (is_conjugate()) { + updater->save_posterior_hypers(prior->get_hypers_proto()); + } initialize_state(); - posterior_hypers = *hypers; - clear_data(); - clear_summary_statistics(); - } - - //! Sets the (pointer to the) dataset matrix - void set_dataset(const Eigen::MatrixXd *const dataset) override { - dataset_ptr = dataset; - } + like->clear_data(); + like->clear_summary_statistics(); + }; - protected: - //! 
Raises an error if the prior pointer is not initialized - void check_prior_is_set() const { - if (prior == nullptr) { - throw std::invalid_argument("Hierarchy prior was not provided"); - } - } + //! Returns whether the hierarchy models multivariate data or not + bool is_multivariate() const override { return like->is_multivariate(); }; - //! Re-initializes the prior of the hierarchy to a newly created object - void create_empty_prior() { prior.reset(new Prior); } + //! Returns whether the hierarchy depends on covariate values or not + bool is_dependent() const override { return like->is_dependent(); }; - //! Re-initializes the hypers of the hierarchy to a newly created object - void create_empty_hypers() { hypers.reset(new Hyperparams); } + //! Returns whether the hierarchy represents a conjugate model or not + bool is_conjugate() const override { return updater->is_conjugate(); }; - //! Sets the cardinality of the cluster - void set_card(const int card_) { - card = card_; - log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); + //! Sets the (pointer to the) dataset matrix + void set_dataset(const Eigen::MatrixXd *const dataset) override { + like->set_dataset(dataset); } + protected: //! Initializes state parameters to appropriate values virtual void initialize_state() = 0; - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - virtual std::shared_ptr - get_hypers_proto() const = 0; - - //! Initializes hierarchy hyperparameters to appropriate values - virtual void initialize_hypers() = 0; - - //! Resets cardinality and indexes of data in this cluster - void clear_data() { - set_card(0); - cluster_data_idx = std::set(); - } - - virtual void clear_summary_statistics() = 0; - - //! 
Down-casts the given generic proto message to a ClusterState proto - bayesmix::AlgorithmState::ClusterState *downcast_state( - google::protobuf::Message *const state_) const { - return google::protobuf::internal::down_cast< - bayesmix::AlgorithmState::ClusterState *>(state_); - } - - //! Down-casts the given generic proto message to a ClusterState proto - const bayesmix::AlgorithmState::ClusterState &downcast_state( - const google::protobuf::Message &state_) const { - return google::protobuf::internal::down_cast< - const bayesmix::AlgorithmState::ClusterState &>(state_); - } - - //! Down-casts the given generic proto message to a HierarchyHypers proto - bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( - google::protobuf::Message *const state_) const { - return google::protobuf::internal::down_cast< - bayesmix::AlgorithmState::HierarchyHypers *>(state_); - } - - //! Down-casts the given generic proto message to a HierarchyHypers proto - const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( - const google::protobuf::Message &state_) const { - return google::protobuf::internal::down_cast< - const bayesmix::AlgorithmState::HierarchyHypers &>(state_); - } - - //! Container for state values - State state; - - //! Container for prior hyperparameters values - std::shared_ptr hypers; - - //! Container for posterior hyperparameters values - Hyperparams posterior_hypers; - - //! Pointer to a Protobuf prior object for this class - std::shared_ptr prior; - - //! Set of indexes of data points belonging to this cluster - std::set cluster_data_idx; - - //! Current cardinality of this cluster - int card = 0; - - //! Logarithm of current cardinality of this cluster - double log_card = stan::math::NEGATIVE_INFTY; - - //! 
Pointer to the dataset matrix for the mixture model - const Eigen::MatrixXd *dataset_ptr = nullptr; -}; - -template -void BaseHierarchy::add_datum( - const int id, const Eigen::RowVectorXd &datum, - const bool update_params /*= false*/, - const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) { - assert(cluster_data_idx.find(id) == cluster_data_idx.end()); - card += 1; - log_card = std::log(card); - static_cast(this)->update_ss(datum, covariate, true); - cluster_data_idx.insert(id); - if (update_params) { - static_cast(this)->save_posterior_hypers(); - } -} - -template -void BaseHierarchy::remove_datum( - const int id, const Eigen::RowVectorXd &datum, - const bool update_params /*= false*/, - const Eigen::RowVectorXd &covariate /* = Eigen::RowVectorXd(0)*/) { - static_cast(this)->update_ss(datum, covariate, false); - set_card(card - 1); - auto it = cluster_data_idx.find(id); - assert(it != cluster_data_idx.end()); - cluster_data_idx.erase(it); - if (update_params) { - static_cast(this)->save_posterior_hypers(); - } -} - -template -void BaseHierarchy::write_state_to_proto( - google::protobuf::Message *const out) const { - std::shared_ptr state_ = - get_state_proto(); - auto *out_cast = downcast_state(out); - out_cast->CopyFrom(*state_.get()); - out_cast->set_cardinality(card); -} - -template -void BaseHierarchy::write_hypers_to_proto( - google::protobuf::Message *const out) const { - std::shared_ptr hypers_ = - get_hypers_proto(); - auto *out_cast = downcast_hypers(out); - out_cast->CopyFrom(*hypers_.get()); -} - -template -Eigen::VectorXd -BaseHierarchy::like_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { - Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->get_like_lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique 
covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->get_like_lpdf( - data.row(i), covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->get_like_lpdf( - data.row(i), covariates.row(i)); + //! Evaluates the log-marginal distribution of data in a single point + //! @param hier_params Pointer to the container (prior or posterior) + //! hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf + virtual double marg_lpdf(ProtoHypersPtr hier_params, + const Eigen::RowVectorXd &datum) const { + if (!is_conjugate()) { + throw std::runtime_error( + "Call marg_lpdf() for a non-conjugate hierarchy"); + } else { + throw std::runtime_error( + "marg_lpdf() not implemented for this hierarchy"); } } - return lpdf; -} - -template -void BaseHierarchy::sample_full_cond( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) { - clear_data(); - clear_summary_statistics(); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - covariates.row(i)); + + //! Evaluates the log-marginal distribution of data in a single point + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values + //! @param datum Point which is to be evaluated + //! @param covariate Covariate vector associated to datum + //! 
@return The evaluation of the lpdf + virtual double marg_lpdf(ProtoHypersPtr hier_params, + const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const { + if (!is_conjugate()) { + throw std::runtime_error( + "Call marg_lpdf() for a non-conjugate hierarchy"); + } else { + throw std::runtime_error( + "marg_lpdf() not implemented for this hierarchy"); } } - static_cast(this)->sample_full_cond(true); -} +}; #endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ diff --git a/src/hierarchies/conjugate_hierarchy.h b/src/hierarchies/conjugate_hierarchy.h deleted file mode 100644 index 4d2430bea..000000000 --- a/src/hierarchies/conjugate_hierarchy.h +++ /dev/null @@ -1,206 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_ - -#include "base_hierarchy.h" - -//! Template base class for conjugate hierarchy objects. - -//! This class acts as the base class for conjugate models, i.e. ones for which -//! both the prior and posterior distribution have the same form -//! (non-conjugate hierarchies should instead inherit directly from -//! `BaseHierarchy`). This also means that the marginal distribution for the -//! data is available in closed form. For this reason, each class deriving from -//! this one must have a free method with one of the following signatures, -//! based on whether it depends on covariates or not: -//! double marg_lpdf( -//! const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, -//! const Eigen::RowVectorXd &covariate) const; -//! or -//! double marg_lpdf( -//! const Hyperparams ¶ms, const Eigen::RowVectorXd &datum) const; -//! This returns the evaluation of the marginal distribution on the given data -//! point (and covariate, if any), conditioned on the provided `Hyperparams` -//! object. The latter may contain either prior or posterior values for -//! hyperparameters, depending on where this function is called within the -//! library. -//! 
For more information, please refer to parent classes `AbstractHierarchy` -//! and `BaseHierarchy`. - -template -class ConjugateHierarchy - : public BaseHierarchy { - public: - using BaseHierarchy::hypers; - using BaseHierarchy::posterior_hypers; - using BaseHierarchy::state; - - ConjugateHierarchy() = default; - ~ConjugateHierarchy() = default; - - //! Public wrapper for `marg_lpdf()` methods - virtual double get_marg_lpdf( - const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const; - - //! Evaluates the log-prior predictive distribution of data in a single point - //! @param datum Point which is to be evaluated - //! @param covariate (Optional) covariate vector associated to datum - //! @return The evaluation of the lpdf - double prior_pred_lpdf(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = - Eigen::RowVectorXd(0)) const override { - return get_marg_lpdf(*hypers, datum, covariate); - } - - //! Evaluates the log-conditional predictive distr. of data in a single point - //! @param datum Point which is to be evaluated - //! @param covariate (Optional) covariate vector associated to datum - //! @return The evaluation of the lpdf - double conditional_pred_lpdf(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = - Eigen::RowVectorXd(0)) const override { - return get_marg_lpdf(posterior_hypers, datum, covariate); - } - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - virtual Eigen::VectorXd prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! 
@param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - virtual Eigen::VectorXd conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Generates new state values from the centering posterior distribution - //! @param update_params Save posterior hypers after the computation? - void sample_full_cond(const bool update_params = true) override { - if (this->card == 0) { - // No posterior update possible - static_cast(this)->sample_prior(); - } else { - Hyperparams params = - update_params - ? static_cast(this)->compute_posterior_hypers() - : posterior_hypers; - state = static_cast(this)->draw(params); - } - } - - //! Saves posterior hyperparameters to the corresponding class member - void save_posterior_hypers() { - posterior_hypers = - static_cast(this)->compute_posterior_hypers(); - } - - //! Returns whether the hierarchy represents a conjugate model or not - bool is_conjugate() const override { return true; } - - protected: - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! @return The evaluation of the lpdf - virtual double marg_lpdf(const Hyperparams ¶ms, - const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const { - if (!this->is_dependent()) { - throw std::runtime_error( - "Cannot call marg_lpdf() from a non-dependent hierarchy"); - } else { - throw std::runtime_error("marg_lpdf() not implemented"); - } - } - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! 
@return The evaluation of the lpdf - virtual double marg_lpdf(const Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const { - if (this->is_dependent()) { - throw std::runtime_error( - "Cannot call marg_lpdf() from a dependent hierarchy"); - } else { - throw std::runtime_error("marg_lpdf() not implemented"); - } - } -}; - -template -double ConjugateHierarchy::get_marg_lpdf( - const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const { - if (this->is_dependent()) { - return marg_lpdf(params, datum, covariate); - } else { - return marg_lpdf(params, datum); - } -} - -template -Eigen::VectorXd -ConjugateHierarchy::prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { - Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->prior_pred_lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->prior_pred_lpdf( - data.row(i), covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->prior_pred_lpdf( - data.row(i), covariates.row(i)); - } - } - return lpdf; -} - -template -Eigen::VectorXd ConjugateHierarchy:: - conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { - Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = 
static_cast(this)->conditional_pred_lpdf( - data.row(i), covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), covariates.row(i)); - } - } - return lpdf; -} - -#endif // BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_ diff --git a/src/hierarchies/fa_hierarchy.cc b/src/hierarchies/fa_hierarchy.cc deleted file mode 100644 index 38a76c79d..000000000 --- a/src/hierarchies/fa_hierarchy.cc +++ /dev/null @@ -1,300 +0,0 @@ -#include "fa_hierarchy.h" - -#include - -#include -#include -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "src/utils/proto_utils.h" -#include "src/utils/rng.h" - -double FAHierarchy::like_lpdf(const Eigen::RowVectorXd& datum) const { - return bayesmix::multi_normal_lpdf_woodbury_chol( - datum, state.mu, state.psi_inverse, state.cov_wood, state.cov_logdet); -} - -FA::State FAHierarchy::draw(const FA::Hyperparams& params) { - auto& rng = bayesmix::Rng::Instance().get(); - FA::State out; - out.mu = params.mutilde; - out.psi = params.beta / (params.alpha0 + 1.); - out.eta = Eigen::MatrixXd::Zero(card, params.q); - out.lambda = Eigen::MatrixXd::Zero(dim, params.q); - - for (size_t j = 0; j < dim; j++) { - out.mu[j] = - stan::math::normal_rng(params.mutilde[j], sqrt(params.phi), rng); - - out.psi[j] = stan::math::inv_gamma_rng(params.alpha0, params.beta[j], rng); - - for (size_t i = 0; i < params.q; i++) { - out.lambda(j, i) = stan::math::normal_rng(0, 1, rng); - } - } - - for (size_t i = 0; i < card; i++) { - for (size_t j = 0; j < params.q; j++) { - out.eta(i, j) = stan::math::normal_rng(0, 1, rng); - } - } - - out.psi_inverse = out.psi.cwiseInverse().asDiagonal(); - compute_wood_factors(out.cov_wood, out.cov_logdet, out.lambda, - out.psi_inverse); - - return out; -} - -void FAHierarchy::initialize_state() { - state.mu = hypers->mutilde; - state.psi = 
hypers->beta / (hypers->alpha0 + 1.); - state.eta = Eigen::MatrixXd::Zero(card, hypers->q); - state.lambda = Eigen::MatrixXd::Zero(dim, hypers->q); - state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); - compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda, - state.psi_inverse); -} - -void FAHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mutilde = bayesmix::to_eigen(prior->fixed_values().mutilde()); - dim = hypers->mutilde.size(); - hypers->beta = bayesmix::to_eigen(prior->fixed_values().beta()); - hypers->phi = prior->fixed_values().phi(); - hypers->alpha0 = prior->fixed_values().alpha0(); - hypers->q = prior->fixed_values().q(); - - // Automatic initialization - if (dim == 0) { - hypers->mutilde = dataset_ptr->colwise().mean(); - dim = hypers->mutilde.size(); - } - if (hypers->beta.size() == 0) { - Eigen::MatrixXd centered = - dataset_ptr->rowwise() - dataset_ptr->colwise().mean(); - auto cov_llt = ((centered.transpose() * centered) / - double(dataset_ptr->rows() - 1.)) - .llt(); - Eigen::MatrixXd precision_matrix( - cov_llt.solve(Eigen::MatrixXd::Identity(dim, dim))); - hypers->beta = - (hypers->alpha0 - 1) * precision_matrix.diagonal().cwiseInverse(); - if (hypers->alpha0 == 1) { - throw std::invalid_argument( - "Scale parameter must be different than 1 when automatic " - "initialization is used"); - } - } - // Check validity - if (dim != hypers->beta.rows()) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - for (size_t j = 0; j < dim; j++) { - if (hypers->beta[j] <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - } - if (hypers->alpha0 <= 0) { - throw std::invalid_argument("Scale parameter must be > 0"); - } - if (hypers->phi <= 0) { - throw std::invalid_argument("Diffusion parameter must be > 0"); - } - if (hypers->q <= 0) { - throw std::invalid_argument("Number of factors must be > 0"); - } - } - - else { - throw 
std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void FAHierarchy::update_hypers( - const std::vector& states) { - auto& rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void FAHierarchy::update_summary_statistics(const Eigen::RowVectorXd& datum, - const bool add) { - if (add) { - data_sum += datum; - } else { - data_sum -= datum; - } -} - -void FAHierarchy::clear_summary_statistics() { - data_sum = Eigen::VectorXd::Zero(dim); -} - -void FAHierarchy::set_state_from_proto( - const google::protobuf::Message& state_) { - auto& statecast = downcast_state(state_); - state.mu = bayesmix::to_eigen(statecast.fa_state().mu()); - state.psi = bayesmix::to_eigen(statecast.fa_state().psi()); - state.eta = bayesmix::to_eigen(statecast.fa_state().eta()); - state.lambda = bayesmix::to_eigen(statecast.fa_state().lambda()); - state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); - compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda, - state.psi_inverse); - set_card(statecast.cardinality()); -} - -std::shared_ptr -FAHierarchy::get_state_proto() const { - bayesmix::FAState state_; - bayesmix::to_proto(state.mu, state_.mutable_mu()); - bayesmix::to_proto(state.psi, state_.mutable_psi()); - bayesmix::to_proto(state.eta, state_.mutable_eta()); - bayesmix::to_proto(state.lambda, state_.mutable_lambda()); - - auto out = std::make_shared(); - out->mutable_fa_state()->CopyFrom(state_); - return out; -} - -void FAHierarchy::set_hypers_from_proto( - const google::protobuf::Message& hypers_) { - auto& hyperscast = downcast_hypers(hypers_).fa_state(); - hypers->mutilde = bayesmix::to_eigen(hyperscast.mutilde()); - hypers->alpha0 = hyperscast.alpha0(); - hypers->beta = bayesmix::to_eigen(hyperscast.beta()); - hypers->phi = hyperscast.phi(); - hypers->q = hyperscast.q(); -} - -std::shared_ptr -FAHierarchy::get_hypers_proto() const { - 
bayesmix::FAPriorDistribution hypers_; - bayesmix::to_proto(hypers->mutilde, hypers_.mutable_mutilde()); - bayesmix::to_proto(hypers->beta, hypers_.mutable_beta()); - hypers_.set_alpha0(hypers->alpha0); - hypers_.set_phi(hypers->phi); - hypers_.set_q(hypers->q); - - auto out = std::make_shared(); - out->mutable_fa_state()->CopyFrom(hypers_); - return out; -} - -void FAHierarchy::sample_full_cond(const bool update_params /*= false*/) { - if (this->card == 0) { - // No posterior update possible - sample_prior(); - } else { - sample_eta(); - sample_mu(); - sample_psi(); - sample_lambda(); - } -} - -void FAHierarchy::sample_eta() { - auto& rng = bayesmix::Rng::Instance().get(); - auto sigma_eta_inv_llt = - (Eigen::MatrixXd::Identity(hypers->q, hypers->q) + - state.lambda.transpose() * state.psi_inverse * state.lambda) - .llt(); - if (state.eta.rows() != card) { - state.eta = Eigen::MatrixXd::Zero(card, state.eta.cols()); - } - Eigen::MatrixXd temp_product( - sigma_eta_inv_llt.solve(state.lambda.transpose() * state.psi_inverse)); - auto iterator = cluster_data_idx.begin(); - for (size_t i = 0; i < card; i++, iterator++) { - Eigen::VectorXd tempvector(dataset_ptr->row( - *iterator)); // TODO use slicing when Eigen is updated to v3.4 - state.eta.row(i) = (bayesmix::multi_normal_prec_chol_rng( - temp_product * (tempvector - state.mu), sigma_eta_inv_llt, rng)); - } -} - -void FAHierarchy::sample_mu() { - auto& rng = bayesmix::Rng::Instance().get(); - Eigen::DiagonalMatrix sigma_mu; - - sigma_mu.diagonal() = - (card * state.psi_inverse.diagonal().array() + hypers->phi) - .cwiseInverse(); - - Eigen::VectorXd sum = (state.eta.colwise().sum()); - - Eigen::VectorXd mumean = - sigma_mu * (hypers->phi * hypers->mutilde + - state.psi_inverse * (data_sum - state.lambda * sum)); - - state.mu = bayesmix::multi_normal_diag_rng(mumean, sigma_mu, rng); -} - -void FAHierarchy::sample_lambda() { - auto& rng = bayesmix::Rng::Instance().get(); - - Eigen::MatrixXd 
temp_etateta(state.eta.transpose() * state.eta); - - for (size_t j = 0; j < dim; j++) { - auto sigma_lambda_inv_llt = - (Eigen::MatrixXd::Identity(hypers->q, hypers->q) + - temp_etateta / state.psi[j]) - .llt(); - Eigen::VectorXd tempsum(card); - const Eigen::VectorXd& data_col = dataset_ptr->col(j); - auto iterator = cluster_data_idx.begin(); - for (size_t i = 0; i < card; i++, iterator++) { - tempsum[i] = data_col( - *iterator); // TODO use slicing when Eigen is updated to v3.4 - } - tempsum = tempsum.array() - state.mu[j]; - tempsum = tempsum.array() / state.psi[j]; - - state.lambda.row(j) = bayesmix::multi_normal_prec_chol_rng( - sigma_lambda_inv_llt.solve(state.eta.transpose() * tempsum), - sigma_lambda_inv_llt, rng); - } -} - -void FAHierarchy::sample_psi() { - auto& rng = bayesmix::Rng::Instance().get(); - - for (size_t j = 0; j < dim; j++) { - double sum = 0; - auto iterator = cluster_data_idx.begin(); - for (size_t i = 0; i < card; i++, iterator++) { - sum += std::pow( - ((*dataset_ptr)(*iterator, j) - - state.mu[j] - // TODO use slicing when Eigen is updated to v3.4 - state.lambda.row(j).dot(state.eta.row(i))), - 2); - } - state.psi[j] = stan::math::inv_gamma_rng(hypers->alpha0 + card / 2, - hypers->beta[j] + sum / 2, rng); - } - state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); - compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda, - state.psi_inverse); -} - -void FAHierarchy::compute_wood_factors( - Eigen::MatrixXd& cov_wood, double& cov_logdet, - const Eigen::MatrixXd& lambda, - const Eigen::DiagonalMatrix& psi_inverse) { - auto [cov_wood_, cov_logdet_] = - bayesmix::compute_wood_chol_and_logdet(psi_inverse, lambda); - cov_logdet = cov_logdet_; - cov_wood = cov_wood_; -} diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index 8b6da31d4..fc7ce3079 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -1,134 +1,111 @@ #ifndef BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ #define 
BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" #include "base_hierarchy.h" #include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" +#include "likelihoods/fa_likelihood.h" +#include "priors/fa_prior_model.h" #include "src/utils/distributions.h" - -//! Mixture of Factor Analysers hierarchy for multivariate data. - -//! This class represents a hierarchical model where data are distributed - -namespace FA { -//! Custom container for State values -struct State { - Eigen::VectorXd mu, psi; - Eigen::MatrixXd eta, lambda, cov_wood; - Eigen::DiagonalMatrix psi_inverse; - double cov_logdet; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - Eigen::VectorXd mutilde, beta; - double phi, alpha0; - unsigned int q; -}; - -}; // namespace FA - -class FAHierarchy : public BaseHierarchy { +#include "updaters/fa_updater.h" + +/** + * Mixture of Factor Analysers hierarchy for multivariate data. + * + * This class represents a hierarchical model where data are distributed + * according to a multivariate Normal likelihood with a specific factorization + * of the covariance matrix (see the `FAHierarchy` class for details). The + * likelihood parameters have a Dirichlet-Laplace distribution x InverseGamma + * centering distribution (see the `FAPriorModel` class for details). That is: + * + * \f[ + * f(x_i \mid \mu, \Sigma, \Lambda) &= N(\mu, \Sigma + \Lambda \Lambda^T) \\ + * \mu &\sim N_p(\tilde \mu, \psi I) \\ + * \Lambda &\sim DL(\alpha) \\ + * \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ + * \sigma^2_j &\sim IG(a,b) \quad j=1,...,p + * \f] + * + * where Lambda is the latent score matrix (size \f$ p \times d \f$ + * with \f$ d << p \f$) and \f$ DL(\alpha) \f$ is the Laplace-Dirichlet + * distribution. See Bhattacharya et al. (2015) for further details + */ + +class FAHierarchy + : public BaseHierarchy { public: FAHierarchy() = default; ~FAHierarchy() = default; - //! 
Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector& - states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - FA::State draw(const FA::Hyperparams& params); - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::FA; } - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message& state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message& hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return true; } - - //! Saves posterior hyperparameters to the corresponding class member - void save_posterior_hypers() { - // No Hyperprior present in the hierarchy - } - - //! Generates new state values from the centering posterior distribution - //! @param update_params Save posterior hypers after the computation? - void sample_full_cond(const bool update_params = false) override; - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! 
@return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd& datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd& datum, - const bool add) override; + //! Sets the default updater algorithm for this hierarchy + void set_default_updater() { updater = std::make_shared(); } //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Gibbs sampling step for state variable eta - void sample_eta(); - - //! Gibbs sampling step for state variable mu - void sample_mu(); - - //! Gibbs sampling step for state variable psi - void sample_psi(); - - //! Gibbs sampling step for state variable lambda - void sample_lambda(); - - //! Helper function to compute factors needed for likelihood evaluation - void compute_wood_factors( - Eigen::MatrixXd& cov_wood, double& cov_logdet, - const Eigen::MatrixXd& lambda, - const Eigen::DiagonalMatrix& psi_inverse); + void initialize_state() override { + // Initialize likelihood dimension to prior one + like->set_dim(prior->get_dim()); + // Get hypers and data dimension + auto hypers = prior->get_hypers(); + unsigned int dim = like->get_dim(); + // Initialize likelihood state + State::FA state; + state.mu = hypers.mutilde; + state.psi = hypers.beta / (hypers.alpha0 + 1.); + state.lambda = Eigen::MatrixXd::Zero(dim, hypers.q); + state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); + state.compute_wood_factors(); + like->set_state(state); + } +}; - //! Sum of data points currently belonging to the cluster - Eigen::VectorXd data_sum; +//! Empirical-Bayes hyperparameters initialization for the FAHierarchy. +//! 
Sets the hyperparameters in `hier` starting from the data on which the user +//! wants to fit the model. +inline void set_fa_hyperparams_from_data(FAHierarchy* hier) { + // TODO test this function + auto dataset_ptr = + std::static_pointer_cast(hier->get_likelihood()) + ->get_dataset(); + auto hypers = + std::static_pointer_cast(hier->get_prior())->get_hypers(); + unsigned int dim = + std::static_pointer_cast(hier->get_likelihood()) + ->get_dim(); + + // Automatic initialization + if (dim == 0) { + hypers.mutilde = dataset_ptr->colwise().mean(); + dim = hypers.mutilde.size(); + } + if (hypers.beta.size() == 0) { + Eigen::MatrixXd centered = + dataset_ptr->rowwise() - dataset_ptr->colwise().mean(); + auto cov_llt = + ((centered.transpose() * centered) / double(dataset_ptr->rows() - 1.)) + .llt(); + Eigen::MatrixXd precision_matrix( + cov_llt.solve(Eigen::MatrixXd::Identity(dim, dim))); + hypers.beta = + (hypers.alpha0 - 1) * precision_matrix.diagonal().cwiseInverse(); + if (hypers.alpha0 == 1) { + throw std::invalid_argument( + "Scale parameter must be different than 1 when automatic " + "initialization is used"); + } + } - //! 
Number of variables for each datum - size_t dim; + bayesmix::AlgorithmState::HierarchyHypers state; + bayesmix::to_proto(hypers.mutilde, + state.mutable_fa_state()->mutable_mutilde()); + bayesmix::to_proto(hypers.beta, state.mutable_fa_state()->mutable_beta()); + state.mutable_fa_state()->set_alpha0(hypers.alpha0); + state.mutable_fa_state()->set_phi(hypers.phi); + state.mutable_fa_state()->set_q(hypers.q); + hier->get_prior()->set_hypers_from_proto(state); }; #endif // BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ diff --git a/src/hierarchies/lapnig_hierarchy.cc b/src/hierarchies/lapnig_hierarchy.cc deleted file mode 100644 index b0d479244..000000000 --- a/src/hierarchies/lapnig_hierarchy.cc +++ /dev/null @@ -1,215 +0,0 @@ -#include "lapnig_hierarchy.h" - -#include - -#include -#include -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "src/utils/rng.h" - -unsigned int LapNIGHierarchy::accepted_ = 0; -unsigned int LapNIGHierarchy::iter_ = 0; - -void LapNIGHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.mean = statecast.uni_ls_state().mean(); - state.scale = statecast.uni_ls_state().var(); - set_card(statecast.cardinality()); -} - -void LapNIGHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).lapnig_state(); - hypers->mean = hyperscast.mean(); - hypers->var = hyperscast.var(); - hypers->scale = hyperscast.scale(); - hypers->shape = hyperscast.shape(); - hypers->mh_mean_var = hyperscast.mh_mean_var(); - hypers->mh_log_scale_var = hyperscast.mh_log_scale_var(); -} - -double LapNIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { - return stan::math::double_exponential_lpdf(datum(0), state.mean, - state.scale); -} - -std::shared_ptr -LapNIGHierarchy::get_state_proto() const { - bayesmix::UniLSState state_; - 
state_.set_mean(state.mean); - state_.set_var(state.scale); - - auto out = std::make_shared(); - out->mutable_uni_ls_state()->CopyFrom(state_); - return out; -} - -std::shared_ptr -LapNIGHierarchy::get_hypers_proto() const { - bayesmix::LapNIGState hypers_; - hypers_.set_mean(hypers->mean); - hypers_.set_var(hypers->var); - hypers_.set_shape(hypers->shape); - hypers_.set_scale(hypers->scale); - hypers_.set_mh_mean_var(hypers->mh_mean_var); - hypers_.set_mh_log_scale_var(hypers->mh_log_scale_var); - - auto out = std::make_shared(); - out->mutable_lapnig_state()->CopyFrom(hypers_); - return out; -} - -void LapNIGHierarchy::clear_summary_statistics() { - cluster_data_values.clear(); - sum_abs_diff_curr = 0; - sum_abs_diff_prop = 0; -} - -void LapNIGHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = prior->fixed_values().mean(); - hypers->var = prior->fixed_values().var(); - hypers->shape = prior->fixed_values().shape(); - hypers->scale = prior->fixed_values().scale(); - hypers->mh_mean_var = prior->fixed_values().mh_mean_var(); - hypers->mh_log_scale_var = prior->fixed_values().mh_log_scale_var(); - // Check validity - if (hypers->var <= 0) { - throw std::invalid_argument("Variance-scaling parameter must be > 0"); - } - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("Scale parameter must be > 0"); - } - if (hypers->mh_mean_var <= 0) { - throw std::invalid_argument("mh_mean_var parameter must be > 0"); - } - if (hypers->mh_log_scale_var <= 0) { - throw std::invalid_argument("mh_log_scale_var parameter must be > 0"); - } - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void LapNIGHierarchy::initialize_state() { - state.mean = hypers->mean; - state.scale = hypers->scale / (hypers->shape + 1); // mode of Inv-Gamma -} - -void LapNIGHierarchy::update_hypers( - const std::vector &states) { 
- auto &rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -LapNIG::State LapNIGHierarchy::draw(const LapNIG::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - LapNIG::State out{}; - out.scale = stan::math::inv_gamma_rng(params.shape, 1. / params.scale, rng); - out.mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng); - return out; -} - -void LapNIGHierarchy::update_summary_statistics( - const Eigen::RowVectorXd &datum, const bool add) { - if (add) { - sum_abs_diff_curr += std::abs(state.mean - datum(0, 0)); - cluster_data_values.push_back(datum); - } else { - sum_abs_diff_curr -= std::abs(state.mean - datum(0, 0)); - auto it = std::find(cluster_data_values.begin(), cluster_data_values.end(), - datum); - cluster_data_values.erase(it); - } -} - -void LapNIGHierarchy::sample_full_cond(const bool update_params /*= false*/) { - if (this->card == 0) { - // No posterior update possible - this->sample_prior(); - } else { - // Number of iterations to compute the acceptance rate - ++iter_; - - // Random generator - auto &rng = bayesmix::Rng::Instance().get(); - - // Candidate mean and candidate log_scale - Eigen::VectorXd curr_unc_params(2); - curr_unc_params << state.mean, std::log(state.scale); - - Eigen::VectorXd prop_unc_params = propose_rwmh(curr_unc_params); - - double log_target_prop = - eval_prior_lpdf_unconstrained(prop_unc_params) + - eval_like_lpdf_unconstrained(prop_unc_params, false); - - double log_target_curr = - eval_prior_lpdf_unconstrained(curr_unc_params) + - eval_like_lpdf_unconstrained(curr_unc_params, true); - - double log_a_rate = log_target_prop - log_target_curr; - - if (std::log(stan::math::uniform_rng(0, 1, rng)) < log_a_rate) { - ++accepted_; - state.mean = prop_unc_params(0); - state.scale = std::exp(prop_unc_params(1)); - sum_abs_diff_curr = sum_abs_diff_prop; - } - } -} - -Eigen::VectorXd 
LapNIGHierarchy::propose_rwmh( - const Eigen::VectorXd &curr_vals) { - auto &rng = bayesmix::Rng::Instance().get(); - double candidate_mean = - curr_vals(0) + stan::math::normal_rng(0, sqrt(hypers->mh_mean_var), rng); - double candidate_log_scale = - curr_vals(1) + - stan::math::normal_rng(0, sqrt(hypers->mh_log_scale_var), rng); - Eigen::VectorXd proposal(2); - proposal << candidate_mean, candidate_log_scale; - return proposal; -} - -double LapNIGHierarchy::eval_prior_lpdf_unconstrained( - const Eigen::VectorXd &unconstrained_parameters) { - double mu = unconstrained_parameters(0); - double log_scale = unconstrained_parameters(1); - double scale = std::exp(log_scale); - return stan::math::normal_lpdf(mu, hypers->mean, std::sqrt(hypers->var)) + - stan::math::inv_gamma_lpdf(scale, hypers->shape, hypers->scale) + - log_scale; -} - -double LapNIGHierarchy::eval_like_lpdf_unconstrained( - const Eigen::VectorXd &unconstrained_parameters, const bool is_current) { - double mean = unconstrained_parameters(0); - double log_scale = unconstrained_parameters(1); - double scale = std::exp(log_scale); - double diff_sum = 0; // Sum of absolute values of data - candidate_mean - if (is_current) { - diff_sum = sum_abs_diff_curr; - } else { - for (auto &elem : cluster_data_values) { - diff_sum += std::abs(elem(0, 0) - mean); - } - sum_abs_diff_prop = diff_sum; - } - return std::log(0.5 / scale) + (-0.5 / scale * diff_sum); -} diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/lapnig_hierarchy.h index cdb07e55b..01574da19 100644 --- a/src/hierarchies/lapnig_hierarchy.h +++ b/src/hierarchies/lapnig_hierarchy.h @@ -1,146 +1,57 @@ #ifndef BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ -#include - -#include -#include -#include -#include - -#include "algorithm_state.pb.h" #include "base_hierarchy.h" #include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Laplace Normal-InverseGamma hierarchy for univariate data. - -//! 
This class represents a hierarchical model where data are distributed
-//! according to a laplace likelihood, the parameters of which have a
-//! Normal-InverseGamma centering distribution. That is:
-//! f(x_i|mu,lambda) = Laplace(mu,lambda)
-//! (mu,lambda) ~ N-IG(mu0, lambda0, alpha0, beta0)
-//! The state is composed of mean and scale. The state hyperparameters,
-//! contained in the Hypers object, are (mu_0, lambda0, alpha0, beta0,
-//! scale_var, mean_var), all scalar values. Note that this hierarchy is NOT
-//! conjugate, thus the marginal distribution is not available in closed form.
-//! The hyperprameters scale_var and mean_var are used to perform a step of
-//! Random Walk Metropolis Hastings to sample from the full conditionals. For
-//! more information, please refer to parent classes: `AbstractHierarchy` and
-//! `BaseHierarchy`.
-
-namespace LapNIG {
-//! Custom container for State values
-struct State {
-  double mean, scale;
-};
-
-//! Custom container for Hyperparameters values
-struct Hyperparams {
-  double mean, var, shape, scale, mh_log_scale_var, mh_mean_var;
-};
-}  // namespace LapNIG
+#include "likelihoods/laplace_likelihood.h"
+#include "priors/nxig_prior_model.h"
+#include "updaters/mala_updater.h"
+
+/**
+ * Laplace Normal-InverseGamma hierarchy for univariate data.
+ *
+ * This class represents a hierarchical model where data are distributed
+ * according to a Laplace likelihood (see the `LaplaceLikelihood` class for
+ * details). The likelihood parameters have a Normal x InverseGamma centering
+ * distribution (see the `NxIGPriorModel` class for details). That is:
+ *
+ * \f[
+ *    f(x_i \mid \mu,\sigma^2) &= Laplace(\mu,\sqrt{\sigma^2/2})\\
+ *    \mu &\sim N(\mu_0,\eta^2) \\
+ *    \sigma^2 &\sim InvGamma(a, b)
+ * \f]
+ * The state is composed of mean and variance (thus the scale for the Laplace
+ * distribution is \f$ \sqrt{\sigma^2/2}) \f$. The state hyperparameters are
+ * \f$(\mu_0, \eta^2, a, b)\f$, all scalar values. 
Note that this hierarchy + * is NOT conjugate, thus the marginal distribution is not available in closed + * form. + */ class LapNIGHierarchy - : public BaseHierarchy { + : public BaseHierarchy { public: LapNIGHierarchy() = default; ~LapNIGHierarchy() = default; - //! Counters for tracking acceptance rate in MH step - static unsigned int accepted_; - static unsigned int iter_; - //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::LapNIG; } - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - void save_posterior_hypers() { - throw std::runtime_error("save_posterior_hypers() not implemented"); - } - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return false; } - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! 
Updates state values using the given (prior or posterior) hyperparameters - LapNIG::State draw(const LapNIG::Hyperparams ¶ms); - - //! Generates new state values from the centering posterior distribution - //! @param update_params Save posterior hypers after the computation? - void sample_full_cond(const bool update_params = false) override; - - protected: - //! Set of values of data points belonging to this cluster - std::list cluster_data_values; - - //! Sum of absolute differences for current params - double sum_abs_diff_curr = 0; - - //! Sum of absolute differences for proposal params - double sum_abs_diff_prop = 0; - - //! Samples from the proposal distribution using Random Walk - //! Metropolis-Hastings - Eigen::VectorXd propose_rwmh(const Eigen::VectorXd &curr_vals); - - //! Evaluates the prior given the mean (unconstrained_parameters(0)) - //! and log of the scale (unconstrained_parameters(1)) - double eval_prior_lpdf_unconstrained( - const Eigen::VectorXd &unconstrained_parameters); - - //! Evaluates the (sum of the) log likelihood for all the observations in the - //! cluster given the mean (unconstrained_parameters(0)) - //! and log of the scale (unconstrained_parameters(1)). - //! The parameter "is_current" is used to identify if the evaluation of the - //! likelihood is on the current or on the proposed parameters, in order to - //! avoid repeating calculations of the sum of the absolute differences - double eval_like_lpdf_unconstrained( - const Eigen::VectorXd &unconstrained_parameters, const bool is_current); - - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! 
@param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - const bool add) override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; + //! Sets the default updater algorithm for this hierarchy + void set_default_updater() { updater = std::make_shared(); } //! Initializes state parameters to appropriate values - void initialize_state() override; + void initialize_state() override { + // Get hypers + auto hypers = prior->get_hypers(); + // Initialize likelihood state + State::UniLS state; + state.mean = hypers.mean; + state.var = hypers.scale / (hypers.shape + 1); + like->set_state(state); + }; }; + #endif // BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt new file mode 100644 index 000000000..df10e8674 --- /dev/null +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -0,0 +1,17 @@ +target_sources(bayesmix PUBLIC + likelihood_internal.h + abstract_likelihood.h + base_likelihood.h + uni_norm_likelihood.h + uni_norm_likelihood.cc + multi_norm_likelihood.h + multi_norm_likelihood.cc + uni_lin_reg_likelihood.h + uni_lin_reg_likelihood.cc + laplace_likelihood.h + laplace_likelihood.cc + fa_likelihood.h + fa_likelihood.cc +) + +add_subdirectory(states) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h new file mode 100644 index 000000000..38425829a --- /dev/null +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -0,0 +1,190 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_ABSTRACT_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_ABSTRACT_LIKELIHOOD_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" + +//! Abstract class for a generic likelihood +//! +//! This class is the basis for a curiously recurring template pattern (CRTP) +//! 
for `Likelihood` objects, and is solely composed of interface functions for
+//! derived classes to use.
+//!
+//! A likelihood can evaluate the log probability density function (lpdf) at a
+//! certain point given the current value of the parameters, or compute
+//! directly the lpdf for the whole cluster.
+//!
+//! Whenever possible, we store in a `Likelihood` instance also the sufficient
+//! statistics of the data allocated to the cluster, in order to speed-up
+//! computations.
+
+class AbstractLikelihood {
+ public:
+  //! Default destructor
+  virtual ~AbstractLikelihood() = default;
+
+  //! Returns an independent, data-less copy of this object
+  virtual std::shared_ptr<AbstractLikelihood> clone() const = 0;
+
+  //! Public wrapper for `compute_lpdf()` methods
+  double lpdf(
+      const Eigen::RowVectorXd &datum,
+      const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const {
+    if (is_dependent() and covariate.size() != 0) {
+      return compute_lpdf(datum, covariate);
+    } else {
+      return compute_lpdf(datum);
+    }
+  }
+
+  //! Evaluates the log likelihood over all the data in the cluster
+  //! given unconstrained parameter values.
+  //! By unconstrained parameters we mean that each entry of
+  //! the parameter vector can range over (-inf, inf).
+  //! Usually, some kind of transformation is required from the unconstrained
+  //! parameterization to the actual parameterization.
+  //! @param unconstrained_params vector collecting the unconstrained
+  //! parameters
+  //! @return The evaluation of the log likelihood over all data in the cluster
+  virtual double cluster_lpdf_from_unconstrained(
+      Eigen::VectorXd unconstrained_params) const {
+    throw std::runtime_error(
+        "cluster_lpdf_from_unconstrained() not implemented for this "
+        "likelihood");
+  }
+
+  //! This version using `stan::math::var` type is required for Stan automatic
+  //! differentiation. Evaluates the log likelihood over all the data in the
+  //! cluster given unconstrained parameter values. By unconstrained parameters
+  //! 
we mean that each entry of the parameter vector can range over (-inf, + //! inf). Usually, some kind of transformation is required from the + //! unconstrained parameterization to the actual parameterization. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood over all data in the cluster + virtual stan::math::var cluster_lpdf_from_unconstrained( + Eigen::Matrix unconstrained_params) + const { + throw std::runtime_error( + "cluster_lpdf_from_unconstrained() not implemented for this " + "likelihood"); + } + + //! Evaluates the log-likelihood of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf + virtual Eigen::VectorXd lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const = 0; + + //! Returns whether the likelihood models multivariate data or not + virtual bool is_multivariate() const = 0; + + //! Returns whether the likelihood depends on covariate values or not + virtual bool is_dependent() const = 0; + + //! Read and set state values from a given Protobuf message + virtual void set_state_from_proto(const google::protobuf::Message &state_, + bool update_card = true) = 0; + + //! Read and set state values from the vector of unconstrained parameters + virtual void set_state_from_unconstrained( + const Eigen::VectorXd &unconstrained_state) = 0; + + //! Writes current state to a Protobuf message by pointer + virtual void write_state_to_proto(google::protobuf::Message *out) const = 0; + + //! Sets the (pointer to) the dataset in the cluster + virtual void set_dataset(const Eigen::MatrixXd *const dataset) = 0; + + //! 
Adds a datum and its index to the likelihood + virtual void add_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) = 0; + + //! Removes a datum and its index from the likelihood + virtual void remove_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) = 0; + + //! Public wrapper for `update_sum_stats()` methods + void update_summary_statistics(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, + bool add) { + if (is_dependent()) { + return update_sum_stats(datum, covariate, add); + } else { + return update_sum_stats(datum, add); + } + } + + //! Resets the values of the summary statistics in the likelihood + virtual void clear_summary_statistics() = 0; + + //! Returns the vector of the unconstrained parameters for this likelihood + virtual Eigen::VectorXd get_unconstrained_state() = 0; + + protected: + //! Evaluates the log-likelihood of data in a single point + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf + virtual double compute_lpdf(const Eigen::RowVectorXd &datum) const { + if (is_dependent()) { + throw std::runtime_error( + "Cannot call compute_lpdf() from a dependent likelihood"); + } else { + throw std::runtime_error( + "compute_lpdf() not implemented for this likelihood"); + } + } + + //! Evaluates the log-likelihood of data in a single point + //! @param datum Point which is to be evaluated + //! @param covariate Covariate vector associated to datum + //! @return The evaluation of the lpdf + virtual double compute_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const { + if (!is_dependent()) { + throw std::runtime_error( + "Cannot call compute_lpdf() from a non-dependent likelihood"); + } else { + throw std::runtime_error( + "compute_lpdf() not implemented for this likelihood"); + } + } + + //! 
Updates cluster statistics when a datum is added or removed from it + //! @param datum Data point which is being added or removed + //! @param add Whether the datum is being added or removed + virtual void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) { + if (is_dependent()) { + throw std::runtime_error( + "Cannot call this function from a dependent hierarchy"); + } else { + throw std::runtime_error("Not implemented"); + } + } + + //! Updates cluster statistics when a datum is added or removed from it + //! @param datum Data point which is being added or removed + //! @param covariate Covariate vector associated to datum + //! @param add Whether the datum is being added or removed + virtual void update_sum_stats(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, + bool add) { + if (!is_dependent()) { + throw std::runtime_error( + "Cannot call this function from a non-dependent hierarchy"); + } else { + throw std::runtime_error("Not implemented"); + } + } +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_ABSTRACT_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h new file mode 100644 index 000000000..40ca9dd1d --- /dev/null +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -0,0 +1,228 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_BASE_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_BASE_LIKELIHOOD_H_ + +#include + +#include +#include +#include +#include + +#include "abstract_likelihood.h" +#include "algorithm_state.pb.h" +#include "likelihood_internal.h" +#include "src/utils/covariates_getter.h" + +//! Base template class of a `Likelihood` object +//! +//! This class derives from `AbstractLikelihood` and is templated over +//! `Derived` (needed for the curiously recurring template pattern) and +//! `State`: an instance of `BaseState` +//! +//! @tparam Derived Name of the implemented derived class +//! 
@tparam State Class name of the container for state values + +template +class BaseLikelihood : public AbstractLikelihood { + public: + //! Default constructor + BaseLikelihood() = default; + + //! Default destructor + ~BaseLikelihood() = default; + + //! Returns an independent, data-less copy of this object + std::shared_ptr clone() const override { + auto out = std::make_shared(static_cast(*this)); + out->clear_data(); + out->clear_summary_statistics(); + return out; + } + + //! Evaluates the log likelihood over all the data in the cluster + //! given unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parametrization to the actual one. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood over all data in the cluster + double cluster_lpdf_from_unconstrained( + Eigen::VectorXd unconstrained_params) const override { + return internal::cluster_lpdf_from_unconstrained( + static_cast(*this), unconstrained_params, 0); + } + + //! This version using `stan::math::var` type is required for Stan automatic + //! differentiation. Evaluates the log likelihood over all the data in the + //! cluster given unconstrained parameter values. By unconstrained parameters + //! we mean that each entry of the parameter vector can range over (-inf, + //! inf). Usually, some kind of transformation is required from the + //! unconstrained parametrization to the actual one. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! 
@return The evaluation of the log likelihood over all data in the cluster + stan::math::var cluster_lpdf_from_unconstrained( + Eigen::Matrix unconstrained_params) + const override { + return internal::cluster_lpdf_from_unconstrained( + static_cast(*this), unconstrained_params, 0); + } + + //! Evaluates the log-likelihood of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf + Eigen::VectorXd lpdf_grid(const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = + Eigen::MatrixXd(0, 0)) const override; + + //! Returns the current cardinality of the cluster + int get_card() const { return card; } + + //! Returns the logarithm of the current cardinality of the cluster + double get_log_card() const { return log_card; } + + //! Returns the indexes of data points belonging to this cluster + std::set get_data_idx() const { return cluster_data_idx; } + + //! Writes current state to a Protobuf message by pointer + void write_state_to_proto(google::protobuf::Message *out) const override; + + //! Returns the class of the current state for the likelihood + State get_state() const { return state; } + + State *mutable_state() { return &state; } + + //! Returns a vector storing the state in its unconstrained form + Eigen::VectorXd get_unconstrained_state() override { + return internal::get_unconstrained_state(state, 0); + } + + //! Updates the state of the likelihood with the object given as input + void set_state(const State &state_, bool update_card = true) { + state = state_; + if (update_card) { + set_card(state.card); + } + }; + + void set_state_from_proto(const google::protobuf::Message &state_, + bool update_card = true) override { + State new_state; + new_state.set_from_proto(downcast_state(state_), update_card); + set_state(new_state, update_card); + } + + //! 
Updates the state of the likelihood starting from its unconstrained form + void set_state_from_unconstrained( + const Eigen::VectorXd &unconstrained_state) override { + internal::set_state_from_unconstrained(state, unconstrained_state, 0); + } + + //! Sets the (pointer to) the dataset in the cluster + void set_dataset(const Eigen::MatrixXd *const dataset) override { + dataset_ptr = dataset; + } + + //! Returns the (pointer to) the dataset in the cluster + const Eigen::MatrixXd *get_dataset() const { return dataset_ptr; } + + //! Adds a datum and its index to the likelihood + void add_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + + //! Removes a datum and its index from the likelihood + void remove_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + + //! Resets cardinality and indexes of data in this cluster + void clear_data() { + set_card(0); + cluster_data_idx = std::set(); + } + + protected: + //! Sets the cardinality of the cluster + void set_card(const int card_) { + card = card_; + log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); + } + + //! Down-casts the given generic proto message to a ClusterState proto + bayesmix::AlgorithmState::ClusterState *downcast_state( + google::protobuf::Message *state_) const { + return google::protobuf::internal::down_cast< + bayesmix::AlgorithmState::ClusterState *>(state_); + } + + //! Down-casts the given generic proto message to a ClusterState proto + const bayesmix::AlgorithmState::ClusterState &downcast_state( + const google::protobuf::Message &state_) const { + return google::protobuf::internal::down_cast< + const bayesmix::AlgorithmState::ClusterState &>(state_); + } + + //! Current state of this cluster + State state; + + //! Current cardinality of this cluster + int card = 0; + + //! 
Logarithm of current cardinality of this cluster + double log_card = stan::math::NEGATIVE_INFTY; + + //! Set of indexes of data points belonging to this cluster + std::set cluster_data_idx; + + //! Pointer to the cluster dataset + const Eigen::MatrixXd *dataset_ptr; +}; + +template +void BaseLikelihood::add_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) { + assert(cluster_data_idx.find(id) == cluster_data_idx.end()); + set_card(++card); + static_cast(this)->update_summary_statistics(datum, covariate, + true); + cluster_data_idx.insert(id); +} + +template +void BaseLikelihood::remove_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) { + static_cast(this)->update_summary_statistics(datum, covariate, + false); + set_card(--card); + auto it = cluster_data_idx.find(id); + assert(it != cluster_data_idx.end()); + cluster_data_idx.erase(it); +} + +template +void BaseLikelihood::write_state_to_proto( + google::protobuf::Message *out) const { + auto *out_cast = downcast_state(out); + out_cast->CopyFrom(state.get_as_proto()); + out_cast->set_cardinality(card); +} + +template +Eigen::VectorXd BaseLikelihood::lpdf_grid( + const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { + Eigen::VectorXd lpdf(data.rows()); + covariates_getter cov_getter(covariates); + for (int i = 0; i < data.rows(); i++) + lpdf(i) = + static_cast(this)->lpdf(data.row(i), cov_getter(i)); + + return lpdf; +} + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_BASE_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/fa_likelihood.cc b/src/hierarchies/likelihoods/fa_likelihood.cc new file mode 100644 index 000000000..81973647d --- /dev/null +++ b/src/hierarchies/likelihoods/fa_likelihood.cc @@ -0,0 +1,21 @@ +#include "fa_likelihood.h" + +#include "src/utils/distributions.h" + +void FALikelihood::clear_summary_statistics() { + data_sum = Eigen::VectorXd::Zero(dim); +} + +double 
FALikelihood::compute_lpdf(const Eigen::RowVectorXd& datum) const { + return bayesmix::multi_normal_lpdf_woodbury_chol( + datum, state.mu, state.psi_inverse, state.cov_wood, state.cov_logdet); +} + +void FALikelihood::update_sum_stats(const Eigen::RowVectorXd& datum, + bool add) { + if (add) { + data_sum += datum; + } else { + data_sum -= datum; + } +} diff --git a/src/hierarchies/likelihoods/fa_likelihood.h b/src/hierarchies/likelihoods/fa_likelihood.h new file mode 100644 index 000000000..3e2e08e40 --- /dev/null +++ b/src/hierarchies/likelihoods/fa_likelihood.h @@ -0,0 +1,53 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_FA_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_FA_LIKELIHOOD_H_ + +#include + +#include +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states/includes.h" + +/** + * A gaussian factor analytic likelihood, using the `State::FA` state. + * Represents the model: + * + * \f[ + * \bm{y}_1,\dots,\bm{y}_k \stackrel{\small\mathrm{iid}}{\sim} N_p(\bm{\mu}, + * \Sigma + \Lambda\Lambda^T), \f] + * + * where Lambda is a \f$ p \times d \f$ matrix, usually \f$ d << p \f$ and \f$ + * \Sigma \f$ is a diagonal matrix. Parameters are stored in a `State::FA` + * state. We store as summary statistics the sum of the \f$ \bm{y}_i \f$'s, but + * it is not sufficient for all the updates involved. Therefore, all the + * observations allocated to a cluster are processed when computing the + * cluster lpdf. 
+ */ + +class FALikelihood : public BaseLikelihood { + public: + FALikelihood() = default; + ~FALikelihood() = default; + bool is_multivariate() const override { return true; }; + bool is_dependent() const override { return false; }; + void clear_summary_statistics() override; + void set_dim(unsigned int dim_) { + dim = dim_; + clear_summary_statistics(); + }; + unsigned int get_dim() const { return dim; }; + Eigen::VectorXd get_data_sum() const { return data_sum; }; + + protected: + double compute_lpdf(const Eigen::RowVectorXd& datum) const override; + void update_sum_stats(const Eigen::RowVectorXd& datum, bool add) override; + + unsigned int dim; + Eigen::VectorXd data_sum; +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_FA_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/laplace_likelihood.cc b/src/hierarchies/likelihoods/laplace_likelihood.cc new file mode 100644 index 000000000..7d99a7efe --- /dev/null +++ b/src/hierarchies/likelihoods/laplace_likelihood.cc @@ -0,0 +1,6 @@ +#include "laplace_likelihood.h" + +double LaplaceLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { + return stan::math::double_exponential_lpdf( + datum(0), state.mean, stan::math::sqrt(state.var / 2.0)); +} diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h new file mode 100644 index 000000000..9d4c25128 --- /dev/null +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -0,0 +1,63 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_LAPLACE_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_LAPLACE_LIKELIHOOD_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states/includes.h" + +/** + * A univariate Laplace likelihood, using the `State::UniLS` state. 
Represents + * the model: + * + * \f[ + * y_1,\dots,y_k \mid \mu, \sigma^2 \stackrel{\small\mathrm{iid}}{\sim} + * Laplace(\mu,\sigma^2), \f] + * + * where \f$ \mu \f$ is the mean and center of the distribution + * and \f$ \sigma^2 \f$ is the variance. The scale parameter \f$ \lambda \f$ is + * then \f$ \sqrt{\sigma^2/2} \f$. These parameters are stored in a + * `State::UniLS` state. Since the Laplace likelihood does not have sufficient + * statistics other than the whole sample, the `update_sum_stats()` method does + * nothing. + */ + +class LaplaceLikelihood + : public BaseLikelihood { + public: + LaplaceLikelihood() = default; + ~LaplaceLikelihood() = default; + bool is_multivariate() const override { return false; }; + bool is_dependent() const override { return false; }; + void clear_summary_statistics() override { return; }; + + template + T cluster_lpdf_from_unconstrained( + const Eigen::Matrix &unconstrained_params) const { + assert(unconstrained_params.size() == 2); + + T mean = unconstrained_params(0); + T var = stan::math::positive_constrain(unconstrained_params(1)); + + T out = 0.; + for (auto it = cluster_data_idx.begin(); it != cluster_data_idx.end(); + ++it) { + out += stan::math::double_exponential_lpdf(dataset_ptr->row(*it), mean, + stan::math::sqrt(var / 2.0)); + } + return out; + } + + protected: + double compute_lpdf(const Eigen::RowVectorXd &datum) const override; + void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override { + return; + }; +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_LAPLACE_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/likelihood_internal.h b/src/hierarchies/likelihoods/likelihood_internal.h new file mode 100644 index 000000000..364a43ca9 --- /dev/null +++ b/src/hierarchies/likelihoods/likelihood_internal.h @@ -0,0 +1,61 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_LIKELIHOOD_INTERNAL_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_LIKELIHOOD_INTERNAL_H_ + +//! 
These functions exploit SFINAE to manage exception handling in all methods +//! required only if end user wants to rely on Metropolis-like updaters. SFINAE +//! (Substitution Failure Is Not An Error) is a C++ rule that applies during +//! overload resolution of function templates: When substituting the explicitly +//! specified or deduced type for the template parameter fails, the +//! specialization is discarded from the overload set instead of causing a +//! compile error. This feature is used in template metaprogramming. + +namespace internal { + +/* SFINAE for cluster_lpdf_from_unconstrained() */ +template +auto cluster_lpdf_from_unconstrained( + const Like &like, Eigen::Matrix unconstrained_params, + int) + -> decltype(like.template cluster_lpdf_from_unconstrained( + unconstrained_params)) { + return like.template cluster_lpdf_from_unconstrained( + unconstrained_params); +} +template +auto cluster_lpdf_from_unconstrained( + const Like &like, Eigen::Matrix unconstrained_params, + double) -> T { + throw(std::runtime_error( + "cluster_lpdf_from_unconstrained() not yet implemented")); +} + +/* SFINAE for get_unconstrained_state() */ +template +auto get_unconstrained_state(const State &state, int) + -> decltype(state.get_unconstrained()) { + return state.get_unconstrained(); +} +template +auto get_unconstrained_state(const State &state, double) -> Eigen::VectorXd { + throw(std::runtime_error("get_unconstrained_state() not yet implemented")); +} + +/* SFINAE for set_state_from_unconstrained() */ +template +auto set_state_from_unconstrained(State &state, + const Eigen::VectorXd &unconstrained_state, + int) + -> decltype(state.set_from_unconstrained(unconstrained_state)) { + state.set_from_unconstrained(unconstrained_state); +} +template +auto set_state_from_unconstrained(State &state, + const Eigen::VectorXd &unconstrained_state, + double) -> void { + throw(std::runtime_error( + "set_state_from_unconstrained() not yet implemented")); +} + +} // namespace internal + 
+#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_LIKELIHOOD_INTERNAL_H_ diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.cc b/src/hierarchies/likelihoods/multi_norm_likelihood.cc new file mode 100644 index 000000000..f0cfae90d --- /dev/null +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.cc @@ -0,0 +1,31 @@ +#include "multi_norm_likelihood.h" + +#include "src/utils/distributions.h" +#include "src/utils/eigen_utils.h" +#include "src/utils/proto_utils.h" + +double MultiNormLikelihood::compute_lpdf( + const Eigen::RowVectorXd &datum) const { + return bayesmix::multi_normal_prec_lpdf(datum, state.mean, state.prec_chol, + state.prec_logdet); +} + +void MultiNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, + bool add) { + // Check if dim is not defined yet (this usually doesn't happen if the + // hierarchy is initialized) + if (!dim) set_dim(datum.size()); + // Updates + if (add) { + data_sum += datum.transpose(); + data_sum_squares += datum.transpose() * datum; + } else { + data_sum -= datum.transpose(); + data_sum_squares -= datum.transpose() * datum; + } +} + +void MultiNormLikelihood::clear_summary_statistics() { + data_sum = Eigen::VectorXd::Zero(dim); + data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); +} diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h new file mode 100644 index 000000000..6e249a75e --- /dev/null +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -0,0 +1,54 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_MULTI_NORM_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_MULTI_NORM_LIKELIHOOD_H_ + +#include + +#include +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states/includes.h" + +/** + * A multivariate normal likelihood, using the `State::MultiLS` state. 
+ * Represents the model: + * + * \f[ + * \bm{y}_1,\dots, \bm{y}_k \stackrel{\small\mathrm{iid}}{\sim} + * N_p(\bm{\mu}, \Sigma), \f] + * + * where \f$ (\bm{\mu}, \Sigma) \f$ are stored in a `State::MultiLS` state. + * The sufficient statistics stored are the sum of the \f$ \bm{y}_i \f$'s + * and the sum of \f$ \bm{y}_i^T \bm{y}_i \f$. + */ + +class MultiNormLikelihood + : public BaseLikelihood { + public: + MultiNormLikelihood() = default; + ~MultiNormLikelihood() = default; + bool is_multivariate() const override { return true; }; + bool is_dependent() const override { return false; }; + void clear_summary_statistics() override; + + void set_dim(unsigned int dim_) { + dim = dim_; + clear_summary_statistics(); + }; + unsigned int get_dim() const { return dim; }; + Eigen::VectorXd get_data_sum() const { return data_sum; }; + Eigen::MatrixXd get_data_sum_squares() const { return data_sum_squares; }; + + protected: + double compute_lpdf(const Eigen::RowVectorXd &datum) const override; + void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; + + unsigned int dim; + Eigen::VectorXd data_sum; + Eigen::MatrixXd data_sum_squares; +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_MULTI_NORM_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/states/CMakeLists.txt b/src/hierarchies/likelihoods/states/CMakeLists.txt new file mode 100644 index 000000000..933c337b4 --- /dev/null +++ b/src/hierarchies/likelihoods/states/CMakeLists.txt @@ -0,0 +1,8 @@ +target_sources(bayesmix PUBLIC + includes.h + base_state.h + uni_ls_state.h + multi_ls_state.h + uni_lin_reg_ls_state.h + fa_state.h +) diff --git a/src/hierarchies/likelihoods/states/base_state.h b/src/hierarchies/likelihoods/states/base_state.h new file mode 100644 index 000000000..dbaae301b --- /dev/null +++ b/src/hierarchies/likelihoods/states/base_state.h @@ -0,0 +1,69 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_BASE_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_BASE_STATE_H_ + 
+#include +#include + +#include "algorithm_state.pb.h" +#include "src/utils/proto_utils.h" + +namespace State { + +/** + * Abstract base class for a generic state + * + * Given a statistical model with likelihood \f$ L(y \mid \tau) \f$ and prior + * \f$ p(\tau) \f$ a State class represents the value of tau at a certain MCMC + * iteration. In addition, each instance stores the cardinality of the number + * of observations in the model. + * + * State classes inheriting from this one should implement the methods + * `set_from_proto()` and `to_proto()`, that are used to deserialize from + * (and serialize to) a `bayesmix::AlgorithmState::ClusterState` + * protocol buffer message. + * + * Optionally, each state can have an "unconstrained" representation, + * where a bijective transformation \f$ B \f$ is applied to \f$ \tau \f$, so + * that the image of \f$ B \f$ is in \f$ \mathbb{R}^d \f$ for some \f$ d \f$. + * This is essential for the default updaters such as `RandomWalkUpdater` + * and `MalaUpdater` to work, but is not necessary for other model-specific + * updaters. + * If such a representation is needed, child classes should also implement + * `get_unconstrained()`, `set_from_unconstrained()`, and `log_det_jac()`. + */ + +class BaseState { + public: + int card; + + using ProtoState = bayesmix::AlgorithmState::ClusterState; + + //! Returns the unconstrained representation \f$ x = B(\tau) \f$ + virtual Eigen::VectorXd get_unconstrained() { return Eigen::VectorXd(0); } + + //! Sets the current state as \f$ \tau = B^{-1}(in) \f$ + //! @param in the unconstrained representation of the state + virtual void set_from_unconstrained(const Eigen::VectorXd &in) {} + + //! Returns the log determinant of the jacobian of \f$ B^{-1} \f$ + virtual double log_det_jac() { return -1; } + + //! Sets the current state from a protobuf object + //! @param state_ a bayesmix::AlgorithmState::ClusterState instance + //! 
@param update_card if true, the current cardinality is updated + virtual void set_from_proto(const ProtoState &state_, bool update_card) = 0; + + //! Returns a `bayesmix::AlgorithmState::ClusterState` representing the + //! current value of the state + virtual ProtoState get_as_proto() const = 0; + + //! Returns a shared pointer to `bayesmix::AlgorithmState::ClusterState` + //! representing the current value of the state + std::shared_ptr to_proto() const { + return std::make_shared(get_as_proto()); + } +}; + +} // namespace State + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_BASE_STATE_H_ diff --git a/src/hierarchies/likelihoods/states/fa_state.h b/src/hierarchies/likelihoods/states/fa_state.h new file mode 100644 index 000000000..77e02c176 --- /dev/null +++ b/src/hierarchies/likelihoods/states/fa_state.h @@ -0,0 +1,80 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_FA_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_FA_STATE_H_ + +#include +#include + +#include "algorithm_state.pb.h" +#include "base_state.h" +#include "src/utils/distributions.h" +#include "src/utils/eigen_utils.h" +#include "src/utils/proto_utils.h" + +namespace State { + +/** + * State of a Factor Analytic model + * + * \f[ + * Y_i = \Lambda\bm{\eta}_i + \bm{\varepsilon} + * \f] + * + * where \f$ Y_i \f$ is a \f$ p \f$-dimensional vetor, \f$ \bm{\eta}_i \f$ is a + * \f$ d \f$-dimensional one, \f$ \Lambda \f$ is a \f$ p \times d \f$ matrix + * and \f$ \bm{\varepsilon} \f$ is an error term with mean zero and diagonal + * covariance matrix \f$ \psi \f$. + * + * For faster likelihood evaluation, we store also the `cov_wood` factor and + * the log determinant of the matrix \f$ \Lambda \Lambda^T + \psi \f$, see + * the `compute_wood_chol_and_logdet(...)` function for more details. + * + * The unconstrained representation for this state is not implemented. 
+ */ + +class FA : public BaseState { + public: + Eigen::VectorXd mu, psi; + Eigen::MatrixXd eta, lambda, cov_wood; + Eigen::DiagonalMatrix psi_inverse; + double cov_logdet; + + using ProtoState = bayesmix::AlgorithmState::ClusterState; + + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } + + mu = bayesmix::to_eigen(state_.fa_state().mu()); + psi = bayesmix::to_eigen(state_.fa_state().psi()); + eta = bayesmix::to_eigen(state_.fa_state().eta()); + lambda = bayesmix::to_eigen(state_.fa_state().lambda()); + psi_inverse = psi.cwiseInverse().asDiagonal(); + compute_wood_factors(); + } + + //! Sets cov_logdet and cov_wood by calling + //! bayesmix::compute_wood_chol_and_logdet() + void compute_wood_factors() { + auto [cov_wood_, cov_logdet_] = + bayesmix::compute_wood_chol_and_logdet(psi_inverse, lambda); + cov_logdet = cov_logdet_; + cov_wood = cov_wood_; + } + + ProtoState get_as_proto() const override { + bayesmix::FAState state_; + bayesmix::to_proto(mu, state_.mutable_mu()); + bayesmix::to_proto(psi, state_.mutable_psi()); + bayesmix::to_proto(eta, state_.mutable_eta()); + bayesmix::to_proto(lambda, state_.mutable_lambda()); + + bayesmix::AlgorithmState::ClusterState out; + out.mutable_fa_state()->CopyFrom(state_); + return out; + } +}; + +} // namespace State + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_FACTOR_ANALYZERS_STATE_H_ diff --git a/src/hierarchies/likelihoods/states/includes.h b/src/hierarchies/likelihoods/states/includes.h new file mode 100644 index 000000000..f4f868c52 --- /dev/null +++ b/src/hierarchies/likelihoods/states/includes.h @@ -0,0 +1,9 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ + +#include "fa_state.h" +#include "multi_ls_state.h" +#include "uni_lin_reg_ls_state.h" +#include "uni_ls_state.h" + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ diff --git 
a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h new file mode 100644 index 000000000..a09221363 --- /dev/null +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -0,0 +1,112 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_MULTI_LS_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_MULTI_LS_STATE_H_ + +#include +#include + +#include "algorithm_state.pb.h" +#include "base_state.h" +#include "src/utils/proto_utils.h" + +namespace State { + +//! Returns the unonstrained parametrization from the +//! unconstrained one, i.e. [mean_in[0], B(prec_in[1])] +//! where B is the `stan::math::cov_matrix_free()` +//! transformation. +template +Eigen::Matrix multi_ls_to_unconstrained( + Eigen::Matrix mean_in, + Eigen::Matrix prec_in) { + Eigen::Matrix prec_out = + stan::math::cov_matrix_free(prec_in); + Eigen::Matrix out(mean_in.size() + prec_out.size()); + out << mean_in, prec_out; + return out; +} + +//! Returns the unonstrained parametrization from the +//! unconstrained one +template +std::tuple, + Eigen::Matrix> +multi_ls_to_constrained(Eigen::Matrix in) { + double dim_ = 0.5 * (std::sqrt(8 * in.size() + 9) - 3); + double dimf; + assert(modf(dim_, &dimf) == 0.0); + int dim = int(dimf); + Eigen::Matrix mean(dim); + mean << in.head(dim); + Eigen::Matrix prec(dim, dim); + prec = stan::math::cov_matrix_constrain(in.tail(in.size() - dim), dim); + return std::make_tuple(mean, prec); +} + +//! Returns the log determinant of the jacobian of the map +//! (x, y) -> (x, B(y)), that is the inverse map of the +//! constrained -> unconstrained representation. +template +T multi_ls_log_det_jac( + Eigen::Matrix prec_constrained) { + T out = 0; + stan::math::cov_matrix_constrain( + stan::math::cov_matrix_free(prec_constrained), out); + return out; +} + +//! A univariate location-scale state with parametrization (mean, Cov) +//! where Cov is the covariance matrix. +//! 
The unconstrained representation corresponds to (mean, B(cov)), where +//! B is the `stan::math::cov_matrix_free()` transformation. +class MultiLS : public BaseState { + public: + Eigen::VectorXd mean; + Eigen::MatrixXd prec, prec_chol; + double prec_logdet; + + using ProtoState = bayesmix::AlgorithmState::ClusterState; + + Eigen::VectorXd get_unconstrained() override { + return multi_ls_to_unconstrained(mean, prec); + } + + void set_from_unconstrained(const Eigen::VectorXd &in) override { + std::tie(mean, prec) = multi_ls_to_constrained(in); + set_from_constrained(mean, prec); + } + + void set_from_constrained(Eigen::VectorXd mean_, Eigen::MatrixXd prec_) { + mean = mean_; + prec = prec_; + prec_chol = Eigen::LLT(prec).matrixL(); + Eigen::VectorXd diag = prec_chol.diagonal(); + prec_logdet = 2 * log(diag.array()).sum(); + } + + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } + + mean = to_eigen(state_.multi_ls_state().mean()); + prec = to_eigen(state_.multi_ls_state().prec()); + prec_chol = to_eigen(state_.multi_ls_state().prec_chol()); + Eigen::VectorXd diag = prec_chol.diagonal(); + prec_logdet = 2 * log(diag.array()).sum(); + } + + ProtoState get_as_proto() const override { + ProtoState state; + bayesmix::to_proto(mean, state.mutable_multi_ls_state()->mutable_mean()); + bayesmix::to_proto(prec, state.mutable_multi_ls_state()->mutable_prec()); + bayesmix::to_proto(prec_chol, + state.mutable_multi_ls_state()->mutable_prec_chol()); + return state; + } + + double log_det_jac() override { return multi_ls_log_det_jac(prec); } +}; + +} // namespace State + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_MULTI_LS_STATE_H_ diff --git a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h new file mode 100644 index 000000000..0b650c878 --- /dev/null +++ b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h @@ -0,0 
+1,101 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_LS_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_LS_STATE_H_ + +#include +#include + +#include "algorithm_state.pb.h" +#include "base_state.h" +#include "src/utils/eigen_utils.h" +#include "src/utils/proto_utils.h" + +namespace State { + +//! Returns the constrained parametrization from the +//! unconstrained one, i.e. [a, exp(b)], +//! where `a` is equal to the vector `in` excluding its last +//! element, and `b` is the last element in `in` +template +Eigen::Matrix uni_lin_reg_to_constrained( + Eigen::Matrix in) { + int N = in.size(); + Eigen::Matrix out(N); + out << in.head(N - 1), stan::math::exp(in(N - 1)); + return out; +} + +//! Returns the unconstrained parametrization from the +//! constrained one, i.e. [a, log(b)] +//! where `a` is equal to the vector `in` excluding its last +//! element, and `b` is the last element in `in` +template +Eigen::Matrix uni_lin_reg_to_unconstrained( + Eigen::Matrix in) { + int N = in.size(); + Eigen::Matrix out(N); + out << in.head(N - 1), stan::math::log(in(N - 1)); + return out; +} + +//! Returns the log determinant of the jacobian of the map +//! (x, y) -> (x, log(y)), that is the inverse map of the +//! constrained -> unconstrained representation. +template +T uni_lin_reg_log_det_jac(Eigen::Matrix constrained) { + T out = 0; + int N = constrained.size(); + stan::math::positive_constrain(stan::math::log(constrained(N - 1)), out); + return out; +} + +//! State of a scalar linear regression model with parameters +//! (regression_coeffs, var), where var is the variance of the error term. +//! The unconstrained representation is (regression_coeffs, log(var)). 
+class UniLinRegLS : public BaseState { + public: + Eigen::VectorXd regression_coeffs; + double var; + + using ProtoState = bayesmix::AlgorithmState::ClusterState; + + Eigen::VectorXd get_unconstrained() override { + Eigen::VectorXd temp(regression_coeffs.size() + 1); + temp << regression_coeffs, var; + return uni_lin_reg_to_unconstrained(temp); + } + + void set_from_unconstrained(const Eigen::VectorXd &in) override { + Eigen::VectorXd temp = uni_lin_reg_to_constrained(in); + int dim = in.size() - 1; + regression_coeffs = temp.head(dim); + var = temp(dim); + } + + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } + regression_coeffs = + bayesmix::to_eigen(state_.lin_reg_uni_ls_state().regression_coeffs()); + var = state_.lin_reg_uni_ls_state().var(); + } + + ProtoState get_as_proto() const override { + ProtoState state; + bayesmix::to_proto( + regression_coeffs, + state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()); + state.mutable_lin_reg_uni_ls_state()->set_var(var); + return state; + } + + double log_det_jac() override { + Eigen::VectorXd temp(regression_coeffs.size() + 1); + temp << regression_coeffs, var; + return uni_lin_reg_log_det_jac(temp); + } +}; + +} // namespace State + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_LS_STATE_H_ diff --git a/src/hierarchies/likelihoods/states/uni_ls_state.h b/src/hierarchies/likelihoods/states/uni_ls_state.h new file mode 100644 index 000000000..c553ef60d --- /dev/null +++ b/src/hierarchies/likelihoods/states/uni_ls_state.h @@ -0,0 +1,87 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LS_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LS_STATE_H_ + +#include +#include + +#include "algorithm_state.pb.h" +#include "base_state.h" +#include "src/utils/proto_utils.h" + +namespace State { + +//! Returns the constrained parametrization from the +//! unconstrained one, i.e. 
[in[0], exp(in[1])] +template +Eigen::Matrix uni_ls_to_constrained( + Eigen::Matrix in) { + Eigen::Matrix out(2); + out << in(0), stan::math::exp(in(1)); + return out; +} + +//! Returns the unconstrained parametrization from the +//! constrained one, i.e. [in[0], log(in[1])] +template +Eigen::Matrix uni_ls_to_unconstrained( + Eigen::Matrix in) { + Eigen::Matrix out(2); + out << in(0), stan::math::log(in(1)); + return out; +} + +//! Returns the log determinant of the jacobian of the map +//! (x, y) -> (x, log(y)), that is the inverse map of the +//! constrained -> unconstrained representation. +template +T uni_ls_log_det_jac(Eigen::Matrix constrained) { + T out = 0; + stan::math::positive_constrain(stan::math::log(constrained(1)), out); + return out; +} + +//! A univariate location-scale state with parametrization (mean, var) +//! The unconstrained representation corresponds to (mean, log(var)) +class UniLS : public BaseState { + public: + double mean, var; + + using ProtoState = bayesmix::AlgorithmState::ClusterState; + + Eigen::VectorXd get_unconstrained() override { + Eigen::VectorXd temp(2); + temp << mean, var; + return uni_ls_to_unconstrained(temp); + } + + void set_from_unconstrained(const Eigen::VectorXd &in) override { + Eigen::VectorXd temp = uni_ls_to_constrained(in); + mean = temp(0); + var = temp(1); + } + + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } + mean = state_.uni_ls_state().mean(); + var = state_.uni_ls_state().var(); + } + + ProtoState get_as_proto() const override { + ProtoState state; + state.mutable_uni_ls_state()->set_mean(mean); + state.mutable_uni_ls_state()->set_var(var); + return state; + } + + double log_det_jac() override { + Eigen::VectorXd temp(2); + temp << mean, var; + return uni_ls_log_det_jac(temp); + } +}; + +} // namespace State + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LS_STATE_H_ diff --git 
a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc new file mode 100644 index 000000000..550eec5fd --- /dev/null +++ b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc @@ -0,0 +1,30 @@ +#include "uni_lin_reg_likelihood.h" + +#include "src/utils/eigen_utils.h" + +void UniLinRegLikelihood::clear_summary_statistics() { + mixed_prod = Eigen::VectorXd::Zero(dim); + data_sum_squares = 0.0; + covar_sum_squares = Eigen::MatrixXd::Zero(dim, dim); +} + +double UniLinRegLikelihood::compute_lpdf( + const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const { + return stan::math::normal_lpdf( + datum(0), state.regression_coeffs.dot(covariate), sqrt(state.var)); +} + +void UniLinRegLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, + bool add) { + if (add) { + data_sum_squares += datum(0) * datum(0); + covar_sum_squares += covariate.transpose() * covariate; + mixed_prod += datum(0) * covariate.transpose(); + } else { + data_sum_squares -= datum(0) * datum(0); + covar_sum_squares -= covariate.transpose() * covariate; + mixed_prod -= datum(0) * covariate.transpose(); + } +} diff --git a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h new file mode 100644 index 000000000..bb9e55687 --- /dev/null +++ b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h @@ -0,0 +1,63 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_LIN_REG_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_LIN_REG_LIKELIHOOD_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states/includes.h" + +/** + * A scalar linear regression model, using the `State::UniLinRegLS` state. 
+ * Represents the model: + * + * \f[ + * y_i \mid \bm{x}_i, \bm{\beta}, \sigma^2 + * \stackrel{\small\mathrm{ind}}{\sim} N(\bm{x}_i^T\bm{\beta},\sigma^2), \f] + * + * where \f$ (\bm{\beta}, \sigma^2) \f$ are stored in a `State::UniLinRegLS` + * state. The sufficient statistics stored are the sum of \f$ y_i^2 \f$, the + * sum of \f$ \bm{x}_i^T \bm{x}_i \f$ and the sum of \f$ y_i \bm{x}_i^T \f$. + */ + +class UniLinRegLikelihood + : public BaseLikelihood { + public: + UniLinRegLikelihood() = default; + ~UniLinRegLikelihood() = default; + bool is_multivariate() const override { return false; }; + bool is_dependent() const override { return true; }; + void clear_summary_statistics() override; + + // Getters and Setters + unsigned int get_dim() const { return dim; }; + void set_dim(unsigned int dim_) { + dim = dim_; + clear_summary_statistics(); + }; + double get_data_sum_squares() const { return data_sum_squares; }; + Eigen::MatrixXd get_covar_sum_squares() const { return covar_sum_squares; }; + Eigen::VectorXd get_mixed_prod() const { return mixed_prod; }; + + protected: + double compute_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const override; + void update_sum_stats(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, + bool add) override; + + // Dimension of the coefficients vector + unsigned int dim; + // Represents pieces of y^t y + double data_sum_squares; + // Represents pieces of X^T X + Eigen::MatrixXd covar_sum_squares; + // Represents pieces of X^t y + Eigen::VectorXd mixed_prod; +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_LIN_REG_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc new file mode 100644 index 000000000..3b5cdf06e --- /dev/null +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -0,0 +1,21 @@ +#include "uni_norm_likelihood.h" + +double UniNormLikelihood::compute_lpdf(const Eigen::RowVectorXd 
&datum) const { + return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); +} + +void UniNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, + bool add) { + if (add) { + data_sum += datum(0); + data_sum_squares += datum(0) * datum(0); + } else { + data_sum -= datum(0); + data_sum_squares -= datum(0) * datum(0); + } +} + +void UniNormLikelihood::clear_summary_statistics() { + data_sum = 0; + data_sum_squares = 0; +} diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h new file mode 100644 index 000000000..e278a3635 --- /dev/null +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -0,0 +1,58 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_NORM_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_NORM_LIKELIHOOD_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states/includes.h" + +/** + * A univariate normal likelihood, using the `State::UniLS` state. Represents + * the model: + * + * \f[ + * y_1, \dots, y_k \mid \mu, \sigma^2 \stackrel{\small\mathrm{iid}}{\sim} + * N(\mu, \sigma^2), \f] + * + * where \f$ (\mu, \sigma^2) \f$ are stored in a `State::UniLS` state. + * The sufficient statistics stored are the sum of the \f$ y_i \f$'s and the + * sum of \f$ y_i^2 \f$. 
+ */ + +class UniNormLikelihood + : public BaseLikelihood { + public: + UniNormLikelihood() = default; + ~UniNormLikelihood() = default; + bool is_multivariate() const override { return false; }; + bool is_dependent() const override { return false; }; + void clear_summary_statistics() override; + double get_data_sum() const { return data_sum; }; + double get_data_sum_squares() const { return data_sum_squares; }; + + template + T cluster_lpdf_from_unconstrained( + const Eigen::Matrix &unconstrained_params) const { + assert(unconstrained_params.size() == 2); + T mean = unconstrained_params(0); + T var = stan::math::positive_constrain(unconstrained_params(1)); + T out = -(data_sum_squares - 2 * mean * data_sum + card * mean * mean) / + (2 * var); + out -= card * 0.5 * stan::math::log(stan::math::TWO_PI * var); + return out; + } + + protected: + double compute_lpdf(const Eigen::RowVectorXd &datum) const override; + void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; + + double data_sum = 0; + double data_sum_squares = 0; +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_NORM_LIKELIHOOD_H_ diff --git a/src/hierarchies/lin_reg_uni_hierarchy.cc b/src/hierarchies/lin_reg_uni_hierarchy.cc deleted file mode 100644 index 89c469c96..000000000 --- a/src/hierarchies/lin_reg_uni_hierarchy.cc +++ /dev/null @@ -1,166 +0,0 @@ -#include "lin_reg_uni_hierarchy.h" - -#include -#include -#include - -#include "src/utils/eigen_utils.h" -#include "src/utils/proto_utils.h" -#include "src/utils/rng.h" - -double LinRegUniHierarchy::like_lpdf( - const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const { - return stan::math::normal_lpdf( - datum(0), state.regression_coeffs.dot(covariate), sqrt(state.var)); -} - -double LinRegUniHierarchy::marg_lpdf( - const LinRegUni::Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const { - double sig_n = sqrt( - (1 + (covariate * params.var_scaling_inv * 
covariate.transpose())(0)) * - params.scale / params.shape); - return stan::math::student_t_lpdf(datum(0), 2 * params.shape, - covariate.dot(params.mean), sig_n); -} - -void LinRegUniHierarchy::initialize_state() { - state.regression_coeffs = hypers->mean; - state.var = hypers->scale / (hypers->shape + 1); -} - -void LinRegUniHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); - dim = hypers->mean.size(); - hypers->var_scaling = - bayesmix::to_eigen(prior->fixed_values().var_scaling()); - hypers->var_scaling_inv = stan::math::inverse_spd(hypers->var_scaling); - hypers->shape = prior->fixed_values().shape(); - hypers->scale = prior->fixed_values().scale(); - // Check validity - if (dim != hypers->var_scaling.rows()) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - bayesmix::check_spd(hypers->var_scaling); - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void LinRegUniHierarchy::update_hypers( - const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -LinRegUni::State LinRegUniHierarchy::draw( - const LinRegUni::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - LinRegUni::State out; - out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - out.regression_coeffs = stan::math::multi_normal_prec_rng( - params.mean, params.var_scaling / out.var, rng); - return out; -} - -void LinRegUniHierarchy::update_summary_statistics( - const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate, - const bool add) { - if (add) { - 
data_sum_squares += datum(0) * datum(0); - covar_sum_squares += covariate.transpose() * covariate; - mixed_prod += datum(0) * covariate.transpose(); - } else { - data_sum_squares -= datum(0) * datum(0); - covar_sum_squares -= covariate.transpose() * covariate; - mixed_prod -= datum(0) * covariate.transpose(); - } -} - -void LinRegUniHierarchy::clear_summary_statistics() { - mixed_prod = Eigen::VectorXd::Zero(dim); - data_sum_squares = 0.0; - covar_sum_squares = Eigen::MatrixXd::Zero(dim, dim); -} - -LinRegUni::Hyperparams LinRegUniHierarchy::compute_posterior_hypers() const { - if (card == 0) { // no update possible - return *hypers; - } - // Compute posterior hyperparameters - LinRegUni::Hyperparams post_params; - post_params.var_scaling = covar_sum_squares + hypers->var_scaling; - auto llt = post_params.var_scaling.llt(); - post_params.var_scaling_inv = llt.solve(Eigen::MatrixXd::Identity(dim, dim)); - post_params.mean = - llt.solve(mixed_prod + hypers->var_scaling * hypers->mean); - post_params.shape = hypers->shape + 0.5 * card; - post_params.scale = - hypers->scale + - 0.5 * (data_sum_squares + - hypers->mean.transpose() * hypers->var_scaling * hypers->mean - - post_params.mean.transpose() * post_params.var_scaling * - post_params.mean); - return post_params; -} - -void LinRegUniHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.regression_coeffs = - bayesmix::to_eigen(statecast.lin_reg_uni_ls_state().regression_coeffs()); - state.var = statecast.lin_reg_uni_ls_state().var(); - set_card(statecast.cardinality()); -} - -std::shared_ptr -LinRegUniHierarchy::get_state_proto() const { - bayesmix::LinRegUniLSState state_; - bayesmix::to_proto(state.regression_coeffs, - state_.mutable_regression_coeffs()); - state_.set_var(state.var); - - auto out = std::make_shared(); - out->mutable_lin_reg_uni_ls_state()->CopyFrom(state_); - return out; -} - -void 
LinRegUniHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).lin_reg_uni_state(); - hypers->mean = bayesmix::to_eigen(hyperscast.mean()); - hypers->var_scaling = bayesmix::to_eigen(hyperscast.var_scaling()); - hypers->scale = hyperscast.scale(); - hypers->shape = hyperscast.shape(); -} - -std::shared_ptr -LinRegUniHierarchy::get_hypers_proto() const { - bayesmix::MultiNormalIGDistribution hypers_; - bayesmix::to_proto(hypers->mean, hypers_.mutable_mean()); - bayesmix::to_proto(hypers->var_scaling, hypers_.mutable_var_scaling()); - hypers_.set_shape(hypers->shape); - hypers_.set_scale(hypers->scale); - - auto out = std::make_shared(); - out->mutable_lin_reg_uni_state()->CopyFrom(hypers_); - return out; -} diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index c55a2f7c5..2f9a823bf 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -1,144 +1,83 @@ #ifndef BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "conjugate_hierarchy.h" +#include "base_hierarchy.h" #include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Linear regression hierarchy for univariate data. - -//! This class implements a dependent hierarchy which represents the classical -//! univariate Bayesian linear regression model, i.e.: -//! y_i | \beta, x_i, \sigma^2 \sim N(\beta^T x_i, sigma^2) -//! \beta | \sigma^2 \sim N(\mu, sigma^2 Lambda^{-1}) -//! \sigma^2 \sim InvGamma(a, b) -//! -//! The state consists of the `regression_coeffs` \beta, and the `var` sigma^2. -//! Lambda is called the variance-scaling factor. For more information, please -//! refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and -//! `ConjugateHierarchy`. - -namespace LinRegUni { -//! 
Custom container for State values -struct State { - Eigen::VectorXd regression_coeffs; - double var; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - Eigen::VectorXd mean; - Eigen::MatrixXd var_scaling; - Eigen::MatrixXd var_scaling_inv; - double shape; - double scale; -}; -} // namespace LinRegUni +#include "likelihoods/uni_lin_reg_likelihood.h" +#include "priors/mnig_prior_model.h" +#include "updaters/mnig_updater.h" + +/** + * Linear regression hierarchy for univariate data. + * + * This class implements a dependent hierarchy which represents the classical + * univariate Bayesian linear regression model, i.e.: + * + * \f[ + * f(y_i \mid \bm{x}_i,\mu,\sigma^2) &= N(\bm{\beta}^T \bm{x}_i, \sigma^2) + * \\ + * \bm{\beta} \mid \sigma^2 &\sim N_p(\bm{\mu}, \sigma^2 \Lambda^{-1}) \\ + * \sigma^2 &\sim InvGamma(a, b) + * \f] + * + * The state consists of the `regression_coeffs` \f$ \bm{\beta} \f$, and the + * `var` \f$ \sigma^2 \f$. \f$ \Lambda \f$ is called the variance-scaling + * factor. Note that this hierarchy is conjugate, thus the marginal + * distribution is available in closed form. For more information, please refer + * to the parent class `BaseHierarchy`, to the class `UniLinRegLikelihood` for + * details on the likelihood model and to `MNIGPriorModel` for details on the + * prior model. + */ class LinRegUniHierarchy - : public ConjugateHierarchy { + : public BaseHierarchy { public: LinRegUniHierarchy() = default; ~LinRegUniHierarchy() = default; - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - LinRegUni::State draw(const LinRegUni::Hyperparams ¶ms); - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param covariate Covariate vector associated to datum - //! 
@param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate, - const bool add) override; - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::LinRegUni; } - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Returns the dimension of the coefficients vector - unsigned int get_dim() const { return dim; } - - //! Computes and return posterior hypers given data currently in this cluster - LinRegUni::Hyperparams compute_posterior_hypers() const; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return false; } - - //! Returns whether the hierarchy depends on covariate values or not - bool is_dependent() const override { return true; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! 
@return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const override; - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! @return The evaluation of the lpdf - double marg_lpdf(const LinRegUni::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const override; + //! Sets the default updater algorithm for this hierarchy + void set_default_updater() { updater = std::make_shared(); } //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Dimension of the coefficients vector - unsigned int dim; + void initialize_state() override { + // Initialize likelihood dimension to prior one + like->set_dim(prior->get_dim()); + // Get hypers + auto hypers = prior->get_hypers(); + // Initialize likelihood state + State::UniLinRegLS state; + state.regression_coeffs = hypers.mean; + state.var = hypers.scale / (hypers.shape + 1); + like->set_state(state); + }; - //! Represents pieces of y^t y - double data_sum_squares; - - //! Represents pieces of X^T X - Eigen::MatrixXd covar_sum_squares; - - //! Represents pieces of X^t y - Eigen::VectorXd mixed_prod; + //! Evaluates the log-marginal distribution of data in a single point + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values + //! @param datum Point which is to be evaluated + //! @param covariate Covariate vectors associated to data + //! 
@return The evaluation of the lpdf + double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const override { + auto params = hier_params->lin_reg_uni_state(); + Eigen::VectorXd mean = bayesmix::to_eigen(params.mean()); + Eigen::MatrixXd var_scaling = bayesmix::to_eigen(params.var_scaling()); + + auto I = Eigen::MatrixXd::Identity(prior->get_dim(), prior->get_dim()); + Eigen::MatrixXd var_scaling_inv = var_scaling.llt().solve(I); + + double sig_n = + sqrt((1 + (covariate * var_scaling_inv * covariate.transpose())(0)) * + params.scale() / params.shape()); + return stan::math::student_t_lpdf(datum(0), 2 * params.shape(), + covariate.dot(mean), sig_n); + }; }; #endif // BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ diff --git a/src/hierarchies/load_hierarchies.h b/src/hierarchies/load_hierarchies.h index 48cff3a70..21ec46ae5 100644 --- a/src/hierarchies/load_hierarchies.h +++ b/src/hierarchies/load_hierarchies.h @@ -11,6 +11,7 @@ #include "lin_reg_uni_hierarchy.h" #include "nnig_hierarchy.h" #include "nnw_hierarchy.h" +#include "nnxig_hierarchy.h" #include "src/runtime/factory.h" //! 
Loads all available `Hierarchy` objects into the appropriate factory, so @@ -27,6 +28,9 @@ __attribute__((constructor)) static void load_hierarchies() { Builder NNIGbuilder = []() { return std::make_shared(); }; + Builder NNxIGbuilder = []() { + return std::make_shared(); + }; Builder NNWbuilder = []() { return std::make_shared(); }; @@ -40,11 +44,12 @@ __attribute__((constructor)) static void load_hierarchies() { return std::make_shared(); }; - factory.add_builder(LapNIGHierarchy().get_id(), LapNIGbuilder); - factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); factory.add_builder(NNIGHierarchy().get_id(), NNIGbuilder); + factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); factory.add_builder(NNWHierarchy().get_id(), NNWbuilder); + factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); factory.add_builder(FAHierarchy().get_id(), FAbuilder); + factory.add_builder(LapNIGHierarchy().get_id(), LapNIGbuilder); } #endif // BAYESMIX_HIERARCHIES_LOAD_HIERARCHIES_H_ diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index 325389f09..c22393884 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -1,122 +1,69 @@ #ifndef BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "conjugate_hierarchy.h" +#include "base_hierarchy.h" #include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Conjugate Normal Normal-InverseGamma hierarchy for univariate data. - -//! This class represents a hierarchical model where data are distributed -//! according to a normal likelihood, the parameters of which have a -//! Normal-InverseGamma centering distribution. That is: -//! f(x_i|mu,sig) = N(mu,sig^2) -//! (mu,sig^2) ~ N-IG(mu0, lambda0, alpha0, beta0) -//! The state is composed of mean and variance. The state hyperparameters, -//! 
contained in the Hypers object, are (mu_0, lambda0, alpha0, beta0), all -//! scalar values. Note that this hierarchy is conjugate, thus the marginal -//! distribution is available in closed form. For more information, please -//! refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and -//! `ConjugateHierarchy`. - -namespace NNIG { -//! Custom container for State values -struct State { - double mean, var; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - double mean, var_scaling, shape, scale; -}; - -}; // namespace NNIG +#include "likelihoods/uni_norm_likelihood.h" +#include "priors/nig_prior_model.h" +#include "updaters/nnig_updater.h" + +/** + * Conjugate Normal Normal-InverseGamma hierarchy for univariate data. + * + * This class represents a hierarchical model where data are distributed + * according to a Normal likelihood (see the `UniNormLikelihood` class for + * details). The likelihood parameters have a Normal-InverseGamma centering + * distribution (see the `NIGPriorModel` class for details). That is: + * + * \f[ + * f(x_i \mid \mu, \sigma^2) &= N(\mu,\sigma^2) \\ + * (\mu,\sigma^2) & \sim NIG(\mu_0, \lambda_0, \alpha_0, \beta_0) + * \f] + * + * The state is composed of mean and variance. The state hyperparameters are + * \f$(\mu_0, \lambda_0, \alpha_0, \beta_0)\f$, all scalar values. Note that + * this hierarchy is conjugate, thus the marginal distribution is available in + * closed form + */ class NNIGHierarchy - : public ConjugateHierarchy { + : public BaseHierarchy { public: NNIGHierarchy() = default; ~NNIGHierarchy() = default; - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - NNIG::State draw(const NNIG::Hyperparams ¶ms); - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - //! 
Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNIG; } - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Computes and return posterior hypers given data currently in this cluster - NNIG::Hyperparams compute_posterior_hypers() const; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return false; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum) const override; - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double marg_lpdf(const NNIG::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! 
@param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - const bool add) override; + //! Sets the default updater algorithm for this hierarchy + void set_default_updater() { updater = std::make_shared(); } //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; + void initialize_state() override { + // Get hypers + auto hypers = prior->get_hypers(); + // Initialize likelihood state + State::UniLS state; + state.mean = hypers.mean; + state.var = hypers.scale / (hypers.shape + 1); + like->set_state(state); + }; - //! Sum of data points currently belonging to the cluster - double data_sum = 0; - - //! Sum of squared data points currently belonging to the cluster - double data_sum_squares = 0; + //! Evaluates the log-marginal distribution of data in a single point + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values + //! @param datum Point which is to be evaluated + //! 
@return The evaluation of the lpdf + double marg_lpdf(ProtoHypersPtr hier_params, + const Eigen::RowVectorXd &datum) const override { + auto params = hier_params->nnig_state(); + double sig_n = sqrt(params.scale() * (params.var_scaling() + 1) / + (params.shape() * params.var_scaling())); + return stan::math::student_t_lpdf(datum(0), 2 * params.shape(), + params.mean(), sig_n); + } }; #endif // BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index b3bd0afeb..2cf36464a 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -1,168 +1,102 @@ #ifndef BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "conjugate_hierarchy.h" +#include "base_hierarchy.h" #include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Normal Normal-Wishart hierarchy for multivariate data. - -//! This class represents a hierarchy, i.e. a cluster, whose multivariate data -//! are distributed according to a multinomial normal likelihood, the -//! parameters of which have a Normal-Wishart centering distribution. That is: -//! f(x_i|mu,tau) = N(mu,tau^{-1}) -//! (mu,tau) ~ NW(mu0, lambda0, tau0, nu0) -//! The state is composed of mean and precision matrix. The Cholesky factor and -//! log-determinant of the latter are also included in the container for -//! efficiency reasons. The state's hyperparameters, contained in the Hypers -//! object, are (mu0, lambda0, tau0, nu0), which are respectively vector, -//! scalar, matrix, and scalar. Note that this hierarchy is conjugate, thus the -//! marginal distribution is available in closed form. For more information, -//! please refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and -//! `ConjugateHierarchy`. - -namespace NNW { -//! 
Custom container for State values -struct State { - Eigen::VectorXd mean; - Eigen::MatrixXd prec; - Eigen::MatrixXd prec_chol; - double prec_logdet; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - Eigen::VectorXd mean; - double var_scaling; - double deg_free; - Eigen::MatrixXd scale; - Eigen::MatrixXd scale_inv; - Eigen::MatrixXd scale_chol; -}; -} // namespace NNW +#include "likelihoods/multi_norm_likelihood.h" +#include "priors/nw_prior_model.h" +#include "src/utils/distributions.h" +#include "updaters/nnw_updater.h" + +/** + * Normal Normal-Wishart hierarchy for multivariate data. + * + * This class represents a hierarchy whose multivariate data + * are distributed according to a multivariate normal likelihood (see the + * `MultiNormLikelihood` for details). The likelihood parameters have a + * Normal-Wishart centering distribution (see the `NWPriorModel` class for + * details). That is: + * + * \f[ + * f(\bm{x}_i \mid \bm{\mu},\Sigma) &= N_d(\bm{\mu},\Sigma^{-1}) \\ + * (\bm{\mu},\Sigma) &\sim NW(\mu_0, \lambda, \Psi_0, \nu_0) + * \f] + * The state is composed of mean and precision matrix. The Cholesky factor and + * log-determinant of the latter are also included in the container for + * efficiency reasons. The state's hyperparameters are \f$(\mu_0, \lambda, + * \Psi_0, \nu_0)\f$, which are respectively vector, scalar, matrix, and + * scalar. Note that this hierarchy is conjugate, thus the marginal + * distribution is available in closed form + */ class NNWHierarchy - : public ConjugateHierarchy { + : public BaseHierarchy { public: NNWHierarchy() = default; ~NNWHierarchy() = default; - // EVALUATION FUNCTIONS FOR GRIDS OF POINTS - //! Evaluates the log-likelihood of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! 
@return The evaluation of the lpdf - Eigen::VectorXd like_lpdf_grid(const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = - Eigen::MatrixXd(0, 0)) const override; - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - Eigen::VectorXd prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - Eigen::VectorXd conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - NNW::State draw(const NNW::Hyperparams ¶ms); - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNW; } - //! Computes and return posterior hypers given data currently in this cluster - NNW::Hyperparams compute_posterior_hypers() const; - - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! 
Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return true; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum) const override; - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double marg_lpdf(const NNW::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - const bool add) override; - - //! Writes prec and its utilities to the given state object by pointer - void write_prec_to_state(const Eigen::MatrixXd &prec_, NNW::State *out); - - //! Returns parameters for the predictive Student's t distribution - NNW::Hyperparams get_predictive_t_parameters( - const NNW::Hyperparams ¶ms) const; + //! Sets the default updater algorithm for this hierarchy + void set_default_updater() { updater = std::make_shared(); } //! 
Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Dimension of data space - unsigned int dim; + void initialize_state() override { + // Initialize likelihood dimension to prior one + like->set_dim(prior->get_dim()); + // Get hypers and data dimension + auto hypers = prior->get_hypers(); + unsigned int dim = like->get_dim(); + // Initialize likelihood state + State::MultiLS state; + state.mean = hypers.mean; + prior->write_prec_to_state( + hypers.var_scaling * Eigen::MatrixXd::Identity(dim, dim), &state); + like->set_state(state); + }; - //! Sum of data points currently belonging to the cluster - Eigen::VectorXd data_sum; + //! Evaluates the log-marginal distribution of data in a single point + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf + double marg_lpdf(ProtoHypersPtr hier_params, + const Eigen::RowVectorXd &datum) const override { + // TODO check Bayes rule for this hierarchy + HyperParams pred_params = get_predictive_t_parameters(hier_params); + Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); + double logdet = 2 * log(diag.array()).sum(); + return bayesmix::multi_student_t_invscale_lpdf( + datum, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, + logdet); + } - //! Sum of squared data points currently belonging to the cluster - Eigen::MatrixXd data_sum_squares; + //! Helper function that computes the predictive parameters for the + //! multivariate t distribution from the current hyperparameter values. It is + //! used to efficiently compute the log-marginal distribution of data. + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values + //! 
@return A `HyperParam` object with the predictive parameters + HyperParams get_predictive_t_parameters(ProtoHypersPtr hier_params) const { + auto params = hier_params->nnw_state(); + // Compute dof and scale of marginal distribution + unsigned int dim = like->get_dim(); + double nu_n = params.deg_free() - dim + 1; + double coeff = (params.var_scaling() + 1) / (params.var_scaling() * nu_n); + Eigen::MatrixXd scale_chol = + Eigen::LLT(bayesmix::to_eigen(params.scale())) + .matrixU(); + Eigen::MatrixXd scale_chol_n = scale_chol / std::sqrt(coeff); + // Return predictive t parameters + HyperParams out; + out.mean = bayesmix::to_eigen(params.mean()); + out.deg_free = nu_n; + out.scale_chol = scale_chol_n; + return out; + } }; #endif // BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ diff --git a/src/hierarchies/nnxig_hierarchy.h b/src/hierarchies/nnxig_hierarchy.h new file mode 100644 index 000000000..20aecebdd --- /dev/null +++ b/src/hierarchies/nnxig_hierarchy.h @@ -0,0 +1,56 @@ +#ifndef BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ + +#include "base_hierarchy.h" +#include "hierarchy_id.pb.h" +#include "likelihoods/uni_norm_likelihood.h" +#include "priors/nxig_prior_model.h" +#include "updaters/nnxig_updater.h" + +/** + * Semi-conjugate Normal Normal x InverseGamma hierarchy for univariate data. + * + * This class represents a hierarchical model where data are distributed + * according to a Normal likelihood (see the `UniNormLikelihood` class for + * details). The likelihood parameters have a Normal x InverseGamma centering + * distribution (see the `NxIGPriorModel` class for details). That is: + * + * \f[ + * f(x_i \mid \mu,\sigma^2) &= N(\mu,\sigma^2) \\ + * \mu &\sim N(\mu_0, \eta^2) \\ + * \sigma^2 &\sim InvGamma(a, b) + * \f] + * + * The state is composed of mean and variance. The state hyperparameters are + * \f$ (\mu_0, \eta^2, a, b) \f$, all scalar values. 
Note that this hierarchy + * is NOT conjugate, meaning that the marginal distribution is not available + * in closed form + */ + +class NNxIGHierarchy + : public BaseHierarchy { + public: + NNxIGHierarchy() = default; + ~NNxIGHierarchy() = default; + + //! Returns the Protobuf ID associated to this class + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::NNxIG; + } + + //! Sets the default updater algorithm for this hierarchy + void set_default_updater() { updater = std::make_shared(); } + + //! Initializes state parameters to appropriate values + void initialize_state() override { + // Get hypers + auto hypers = prior->get_hypers(); + // Initialize likelihood state + State::UniLS state; + state.mean = hypers.mean; + state.var = hypers.scale / (hypers.shape + 1); + like->set_state(state); + }; +}; + +#endif // BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt new file mode 100644 index 000000000..d6901ee65 --- /dev/null +++ b/src/hierarchies/priors/CMakeLists.txt @@ -0,0 +1,16 @@ +target_sources(bayesmix PUBLIC + prior_model_internal.h + abstract_prior_model.h + base_prior_model.h + hyperparams.h + nig_prior_model.h + nig_prior_model.cc + nxig_prior_model.h + nxig_prior_model.cc + nw_prior_model.h + nw_prior_model.cc + mnig_prior_model.h + mnig_prior_model.cc + fa_prior_model.h + fa_prior_model.cc +) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h new file mode 100644 index 000000000..42b7106a5 --- /dev/null +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -0,0 +1,114 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_ABSTRACT_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_ABSTRACT_PRIOR_MODEL_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "src/hierarchies/likelihoods/states/includes.h" +#include "src/utils/rng.h" + +//! 
Abstract class for a generic prior model +//! +//! This class is the basis for a curiously recurring template pattern (CRTP) +//! for `PriorModel` objects, ad it is solely composed of interface functions +//! for derived classes to use. +//! +//! A prior model represents the prior for the parameters in the likelihood. +//! Hence, it can evaluate the log probability density function (lpdf) for a +//! given parameter state. +//! +//! We also store a pointer to the protobuf object that represents the type of +//! prior used for the parameters in the likelihood. + +class AbstractPriorModel { + public: + // Useful type aliases + using ProtoHypersPtr = + std::shared_ptr; + using ProtoHypers = ProtoHypersPtr::element_type; + + //! Default destructor + virtual ~AbstractPriorModel() = default; + + //! Returns an independent, data-less copy of this object + virtual std::shared_ptr clone() const = 0; + + //! Returns an independent, data-less deep copy of this object + virtual std::shared_ptr deep_clone() const = 0; + + //! Evaluates the log likelihood for the prior model, given the state of the + //! cluster + //! @param state_ A Protobuf message storing the current state of the cluster + //! @return The evaluation of the log likelihood + virtual double lpdf(const google::protobuf::Message &state_) = 0; + + //! Evaluates the log likelihood for unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood of the prior model + virtual double lpdf_from_unconstrained( + Eigen::VectorXd unconstrained_params) const { + throw std::runtime_error("lpdf_from_unconstrained() not yet implemented"); + } + + //! 
Evaluates the log likelihood for unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. This version using + //! `stan::math::var` type is required for Stan automatic aifferentiation. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood of the prior model + virtual stan::math::var lpdf_from_unconstrained( + Eigen::Matrix unconstrained_params) + const { + throw std::runtime_error( + "cluster_lpdf_from_unconstrained() not yet implemented"); + } + + //! Sampling from the prior model + //! @param hier_hypers A pointer to a + //! `bayesmix::AlgorithmState::hierarchyHypers` object, which defines the + //! parameters to use for the sampling. The default behaviour (i.e. + //! `hier_hypers = nullptr`) uses prior hyperparameters + //! @return A Protobuf message storing the state sampled from the prior model + virtual std::shared_ptr sample_proto( + ProtoHypersPtr hier_hypers = nullptr) = 0; + + //! Updates hyperparameter values given a vector of cluster states + virtual void update_hypers( + const std::vector &states) = 0; + + //! Returns a pointer to the Protobuf message of the prior of this cluster + virtual google::protobuf::Message *get_mutable_prior() = 0; + + //! Read and set hyperparameter values from a given Protobuf message + virtual void set_hypers_from_proto( + const google::protobuf::Message &state_) = 0; + + //! Writes current values of the hyperparameters to a Protobuf message by + //! pointer + virtual void write_hypers_to_proto(google::protobuf::Message *out) const = 0; + + //! Writes current value of hyperparameters to a Protobuf message and + //! return a shared_ptr. + //! New hierarchies have to first modify the field 'oneof val' in the + //! 
AlgoritmState::HierarchyHypers message by adding the appropriate type + virtual std::shared_ptr + get_hypers_proto() const = 0; + + protected: + //! Initializes hierarchy hyperparameters to appropriate values + virtual void initialize_hypers() = 0; +}; + +#endif // BAYESMIX_HIERARCHIES_PRIORS_ABSTRACT_PRIOR_MODEL_H_ diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h new file mode 100644 index 000000000..366864b4f --- /dev/null +++ b/src/hierarchies/priors/base_prior_model.h @@ -0,0 +1,199 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_BASE_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_BASE_PRIOR_MODEL_H_ + +#include + +#include +#include +#include +#include +#include + +#include "abstract_prior_model.h" +#include "algorithm_state.pb.h" +#include "hierarchy_id.pb.h" +#include "prior_model_internal.h" +#include "src/utils/rng.h" + +//! Base template class of a `PriorModel` object +//! +//! This class derives from `AbstractPriorModel` and is templated over +//! `Derived` (needed for the curiously recurring template pattern), `State` +//! (an instance of `Basestate`), `HyperParams` (a struct representing the +//! hyperparameters, see `hyperparams.h`) and `Prior`: a protobuf message +//! representing the type of prior imposed on the hyperparameters. +//! +//! @tparam Derived Name of the implemented derived class +//! @tparam State Class name of the container for state values +//! @tparam HyperParams Class name of the container for hyperparameters +//! @tparam Prior Class name of the protobuf message for the prior on the +//! hyperparameters. + +template +class BasePriorModel : public AbstractPriorModel { + public: + //! Default constructor + BasePriorModel() = default; + + //! Default destructor + ~BasePriorModel() = default; + + //! Evaluates the log likelihood for unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! 
Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood of the prior model + double lpdf_from_unconstrained( + Eigen::VectorXd unconstrained_params) const override { + return internal::lpdf_from_unconstrained( + static_cast(*this), unconstrained_params, 0); + }; + + //! This version using `stan::math::var` type is required for Stan automatic + //! aifferentiation. Evaluates the log likelihood for unconstrained parameter + //! values. By unconstrained parameters we mean that each entry of the + //! parameter vector can range over (-inf, inf). Usually, some kind of + //! transformation is required from the unconstrained parameterization to the + //! actual parameterization. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood of the prior model + stan::math::var lpdf_from_unconstrained( + Eigen::Matrix unconstrained_params) + const override { + return internal::lpdf_from_unconstrained( + static_cast(*this), unconstrained_params, 0); + }; + + virtual State sample(ProtoHypersPtr hier_hypers = nullptr) = 0; + + std::shared_ptr sample_proto( + ProtoHypersPtr hier_hypers = nullptr) override { + return sample(hier_hypers).to_proto(); + } + + //! Returns an independent, data-less copy of this object + std::shared_ptr clone() const override; + + //! Returns an independent, data-less deep copy of this object + std::shared_ptr deep_clone() const override; + + //! Returns a pointer to the Protobuf message of the prior of this cluster + google::protobuf::Message *get_mutable_prior() override; + + //! Returns the struct of the current prior hyperparameters + HyperParams get_hypers() const { return *hypers; }; + + //! Writes current values of the hyperparameters to a Protobuf message by + //! 
pointer + void write_hypers_to_proto(google::protobuf::Message *out) const override; + + //! Initializes the prior model (both prior and hyperparameters) + void initialize(); + + protected: + //! Raises an error if the prior pointer is not initialized + void check_prior_is_set() const; + + //! Re-initializes the prior of the hierarchy to a newly created object + void create_empty_prior() { prior.reset(new Prior); }; + + //! Re-initializes the hyperparameters of the hierarchy to a newly created + //! object + void create_empty_hypers() { hypers.reset(new HyperParams); }; + + //! Down-casts the given generic proto message to a HierarchyHypers proto + bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( + google::protobuf::Message *state_) const { + return google::protobuf::internal::down_cast< + bayesmix::AlgorithmState::HierarchyHypers *>(state_); + }; + + //! Down-casts the given generic proto message to a HierarchyHypers proto + const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( + const google::protobuf::Message &state_) const { + return google::protobuf::internal::down_cast< + const bayesmix::AlgorithmState::HierarchyHypers &>(state_); + }; + + //! Down-casts the given generic proto message to a ClusterState proto + const bayesmix::AlgorithmState::ClusterState &downcast_state( + const google::protobuf::Message &state_) const { + return google::protobuf::internal::down_cast< + const bayesmix::AlgorithmState::ClusterState &>(state_); + }; + + //! Container for prior hyperparameters values + std::shared_ptr hypers = std::make_shared(); + + //! 
Pointer to a Protobuf prior object for this class + std::shared_ptr prior; +}; + +/* *** Methods Definitions *** */ +template +std::shared_ptr +BasePriorModel::clone() const { + auto out = std::make_shared(static_cast(*this)); + return out; +} + +template +std::shared_ptr +BasePriorModel::deep_clone() const { + auto out = std::make_shared(static_cast(*this)); + + // Prior Deep-clone + out->create_empty_prior(); + std::shared_ptr new_prior(prior->New()); + new_prior->CopyFrom(*prior.get()); + out->get_mutable_prior()->CopyFrom(*new_prior.get()); + + // HyperParams Deep-clone + out->create_empty_hypers(); + auto curr_hypers_proto = get_hypers_proto(); + out->set_hypers_from_proto(*curr_hypers_proto.get()); + + // Initialization of Deep-cloned object + out->initialize(); + + return out; +} + +template +google::protobuf::Message * +BasePriorModel::get_mutable_prior() { + if (prior == nullptr) { + create_empty_prior(); + } + return prior.get(); +} + +template +void BasePriorModel::write_hypers_to_proto( + google::protobuf::Message *out) const { + std::shared_ptr hypers_ = + get_hypers_proto(); + auto *out_cast = downcast_hypers(out); + out_cast->CopyFrom(*hypers_.get()); +} + +template +void BasePriorModel::initialize() { + check_prior_is_set(); + create_empty_hypers(); + initialize_hypers(); +} + +template +void BasePriorModel::check_prior_is_set() + const { + if (prior == nullptr) { + throw std::invalid_argument("Hierarchy prior was not provided"); + } +} + +#endif // BAYESMIX_HIERARCHIES_PRIORS_BASE_PRIOR_MODEL_H_ diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc new file mode 100644 index 000000000..fa402b1d1 --- /dev/null +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -0,0 +1,122 @@ +#include "fa_prior_model.h" + +double FAPriorModel::lpdf(const google::protobuf::Message &state_) { + // Downcast state + auto &state = downcast_state(state_).fa_state(); + + // Proto2Eigen conversion + Eigen::VectorXd mu = 
bayesmix::to_eigen(state.mu()); + Eigen::VectorXd psi = bayesmix::to_eigen(state.psi()); + Eigen::MatrixXd lambda = bayesmix::to_eigen(state.lambda()); + + // Initialize lpdf value + double target = 0.; + + // Compute lpdf + for (size_t j = 0; j < dim; j++) { + target += + stan::math::normal_lpdf(mu(j), hypers->mutilde(j), sqrt(hypers->phi)); + target += + stan::math::inv_gamma_lpdf(psi(j), hypers->alpha0, hypers->beta(j)); + for (size_t i = 0; i < hypers->q; i++) { + target += stan::math::normal_lpdf(lambda(j, i), 0, 1); + } + } + + // Return lpdf contribution + return target; +} + +State::FA FAPriorModel::sample(ProtoHypersPtr hier_hypers) { + // Random seed + auto &rng = bayesmix::Rng::Instance().get(); + + // Get params to use + auto params = get_hypers_proto()->fa_state(); + Eigen::VectorXd mutilde = bayesmix::to_eigen(params.mutilde()); + Eigen::VectorXd beta = bayesmix::to_eigen(params.beta()); + + // Compute output state + State::FA out; + out.mu = mutilde; + out.psi = beta / (params.alpha0() + 1.); + out.lambda = Eigen::MatrixXd::Zero(dim, params.q()); + for (size_t j = 0; j < dim; j++) { + out.mu[j] = stan::math::normal_rng(mutilde[j], sqrt(params.phi()), rng); + + out.psi[j] = stan::math::inv_gamma_rng(params.alpha0(), beta[j], rng); + + for (size_t i = 0; i < params.q(); i++) { + out.lambda(j, i) = stan::math::normal_rng(0, 1, rng); + } + } + return out; +} + +void FAPriorModel::update_hypers( + const std::vector &states) { + auto &rng = bayesmix::Rng::Instance().get(); + if (prior->has_fixed_values()) { + return; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void FAPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).fa_state(); + hypers->mutilde = bayesmix::to_eigen(hyperscast.mutilde()); + hypers->alpha0 = hyperscast.alpha0(); + hypers->beta = bayesmix::to_eigen(hyperscast.beta()); + hypers->phi = hyperscast.phi(); + hypers->q = 
hyperscast.q(); +} + +std::shared_ptr +FAPriorModel::get_hypers_proto() const { + bayesmix::FAPriorDistribution hypers_; + bayesmix::to_proto(hypers->mutilde, hypers_.mutable_mutilde()); + bayesmix::to_proto(hypers->beta, hypers_.mutable_beta()); + hypers_.set_alpha0(hypers->alpha0); + hypers_.set_phi(hypers->phi); + hypers_.set_q(hypers->q); + + auto out = std::make_shared(); + out->mutable_fa_state()->CopyFrom(hypers_); + return out; +} + +void FAPriorModel::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->mutilde = bayesmix::to_eigen(prior->fixed_values().mutilde()); + dim = hypers->mutilde.size(); + hypers->beta = bayesmix::to_eigen(prior->fixed_values().beta()); + hypers->phi = prior->fixed_values().phi(); + hypers->alpha0 = prior->fixed_values().alpha0(); + hypers->q = prior->fixed_values().q(); + + // Check validity + if (dim != hypers->beta.rows()) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + for (size_t j = 0; j < dim; j++) { + if (hypers->beta[j] <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + } + if (hypers->alpha0 <= 0) { + throw std::invalid_argument("Scale parameter must be > 0"); + } + if (hypers->phi <= 0) { + throw std::invalid_argument("Diffusion parameter must be > 0"); + } + if (hypers->q <= 0) { + throw std::invalid_argument("Number of factors must be > 0"); + } + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h new file mode 100644 index 000000000..9245f4885 --- /dev/null +++ b/src/hierarchies/priors/fa_prior_model.h @@ -0,0 +1,59 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_FA_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_FA_PRIOR_MODEL_H_ + +#include +#include +#include + +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +/** + * A priormodel 
for the factor analyzers likelihood, that is + * + * \f[ + * \bm{\mu} &\sim N_p(\tilde{\bm{\mu}}, \psi I) \\ + * \Lambda &\sim DL(\alpha) \\ + * \Sigma &= \mathrm{diag}(\sigma^2_1, \ldots, \sigma^2_p) \\ + * \sigma^2_j &\stackrel{\small\mathrm{iid}}{\sim} InvGamma(a,b) \quad + * j=1,...,p \f] + * + * Where \f$ DL \f$ is the Dirichlet-Laplace distribution. + * See Bhattacharya A., Pati D., Pillai N.S., Dunson D.B. (2015). + * JASA 110(512), 1479–1490 for details. + */ + +class FAPriorModel + : public BasePriorModel { + public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + + FAPriorModel() = default; + ~FAPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + State::FA sample(ProtoHypersPtr hier_hypers = nullptr) override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + unsigned int get_dim() const { return dim; }; + + std::shared_ptr get_hypers_proto() + const override; + + protected: + void initialize_hypers() override; + + unsigned int dim; +}; + +#endif // BAYESMIX_HIERARCHIES_PRIORS_FA_PRIOR_MODEL_H_ diff --git a/src/hierarchies/priors/hyperparams.h b/src/hierarchies/priors/hyperparams.h new file mode 100644 index 000000000..1aca6dc4a --- /dev/null +++ b/src/hierarchies/priors/hyperparams.h @@ -0,0 +1,36 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_HYPERPARAMS_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_HYPERPARAMS_H_ + +#include + +namespace Hyperparams { + +struct NIG { + double mean, var_scaling, shape, scale; +}; + +struct NxIG { + double mean, var, shape, scale; +}; + +struct NW { + Eigen::VectorXd mean; + double var_scaling, deg_free; + Eigen::MatrixXd scale, scale_inv, scale_chol; +}; + +struct MNIG { + Eigen::VectorXd mean; + Eigen::MatrixXd var_scaling, var_scaling_inv; + double shape, scale; +}; + +struct FA { + Eigen::VectorXd mutilde, beta; + double phi, alpha0; + unsigned int 
q; +}; + +} // namespace Hyperparams + +#endif // BAYESMIX_HIERARCHIES_PRIORS_HYPERPARAMS_H_ diff --git a/src/hierarchies/priors/mnig_prior_model.cc b/src/hierarchies/priors/mnig_prior_model.cc new file mode 100644 index 000000000..2ec02169f --- /dev/null +++ b/src/hierarchies/priors/mnig_prior_model.cc @@ -0,0 +1,84 @@ +#include "mnig_prior_model.h" + +double MNIGPriorModel::lpdf(const google::protobuf::Message &state_) { + auto &state = downcast_state(state_).lin_reg_uni_ls_state(); + Eigen::VectorXd regression_coeffs = + bayesmix::to_eigen(state.regression_coeffs()); + double target = stan::math::multi_normal_prec_lpdf( + regression_coeffs, hypers->mean, hypers->var_scaling / state.var()); + target += + stan::math::inv_gamma_lpdf(state.var(), hypers->shape, hypers->scale); + return target; +} + +State::UniLinRegLS MNIGPriorModel::sample(ProtoHypersPtr hier_hypers) { + auto &rng = bayesmix::Rng::Instance().get(); + auto params = (hier_hypers) ? hier_hypers->lin_reg_uni_state() + : get_hypers_proto()->lin_reg_uni_state(); + Eigen::VectorXd mean = bayesmix::to_eigen(params.mean()); + Eigen::MatrixXd var_scaling = bayesmix::to_eigen(params.var_scaling()); + + State::UniLinRegLS out; + out.var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); + out.regression_coeffs = + stan::math::multi_normal_prec_rng(mean, var_scaling / out.var, rng); + return out; +} + +void MNIGPriorModel::update_hypers( + const std::vector &states) { + if (prior->has_fixed_values()) { + return; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void MNIGPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).lin_reg_uni_state(); + hypers->mean = bayesmix::to_eigen(hyperscast.mean()); + hypers->var_scaling = bayesmix::to_eigen(hyperscast.var_scaling()); + hypers->scale = hyperscast.scale(); + hypers->shape = hyperscast.shape(); +} + +std::shared_ptr 
+MNIGPriorModel::get_hypers_proto() const { + bayesmix::MultiNormalIGDistribution hypers_; + bayesmix::to_proto(hypers->mean, hypers_.mutable_mean()); + bayesmix::to_proto(hypers->var_scaling, hypers_.mutable_var_scaling()); + hypers_.set_shape(hypers->shape); + hypers_.set_scale(hypers->scale); + + auto out = std::make_shared(); + out->mutable_lin_reg_uni_state()->CopyFrom(hypers_); + return out; +} + +void MNIGPriorModel::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); + dim = hypers->mean.size(); + hypers->var_scaling = + bayesmix::to_eigen(prior->fixed_values().var_scaling()); + hypers->var_scaling_inv = stan::math::inverse_spd(hypers->var_scaling); + hypers->shape = prior->fixed_values().shape(); + hypers->scale = prior->fixed_values().scale(); + // Check validity + if (dim != hypers->var_scaling.rows()) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + bayesmix::check_spd(hypers->var_scaling); + if (hypers->shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers->scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h new file mode 100644 index 000000000..9a2e2ddd5 --- /dev/null +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -0,0 +1,53 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_MNIG_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_MNIG_PRIOR_MODEL_H_ + +#include +#include +#include + +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +/** + * A conjugate prior model for the scalar linear regression likelihood, i.e. 
+ * + * \f[ + * \bm{\beta} \mid \sigma^2 & \sim N_p(\bm{\mu}, \sigma^2 \Lambda^{-1}) \\ + * \sigma^2 & \sim InvGamma(a,b) + * \f] + */ + +class MNIGPriorModel + : public BasePriorModel { + public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + + MNIGPriorModel() = default; + ~MNIGPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + State::UniLinRegLS sample(ProtoHypersPtr hier_hypers = nullptr) override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + unsigned int get_dim() const { return dim; }; + + std::shared_ptr get_hypers_proto() + const override; + + protected: + void initialize_hypers() override; + + unsigned int dim; +}; + +#endif // BAYESMIX_HIERARCHIES_PRIORS_MNIG_PRIOR_MODEL_H_ diff --git a/src/hierarchies/nnig_hierarchy.cc b/src/hierarchies/priors/nig_prior_model.cc similarity index 62% rename from src/hierarchies/nnig_hierarchy.cc rename to src/hierarchies/priors/nig_prior_model.cc index ff1ab3870..41756e89c 100644 --- a/src/hierarchies/nnig_hierarchy.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -1,34 +1,6 @@ -#include "nnig_hierarchy.h" +#include "nig_prior_model.h" -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "src/utils/rng.h" - -double NNIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { - return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); -} - -double NNIGHierarchy::marg_lpdf(const NNIG::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const { - double sig_n = sqrt(params.scale * (params.var_scaling + 1) / - (params.shape * params.var_scaling)); - return stan::math::student_t_lpdf(datum(0), 2 * params.shape, params.mean, - sig_n); -} - -void NNIGHierarchy::initialize_state() { - state.mean = hypers->mean; - state.var = 
hypers->scale / (hypers->shape + 1); -} - -void NNIGHierarchy::initialize_hypers() { +void NIGPriorModel::initialize_hypers() { if (prior->has_fixed_values()) { // Set values hypers->mean = prior->fixed_values().mean(); @@ -45,9 +17,7 @@ void NNIGHierarchy::initialize_hypers() { if (hypers->scale <= 0) { throw std::invalid_argument("scale parameter must be > 0"); } - } - - else if (prior->has_normal_mean_prior()) { + } else if (prior->has_normal_mean_prior()) { // Set initial values hypers->mean = prior->normal_mean_prior().mean_prior().mean(); hypers->var_scaling = prior->normal_mean_prior().var_scaling(); @@ -63,9 +33,7 @@ void NNIGHierarchy::initialize_hypers() { if (hypers->scale <= 0) { throw std::invalid_argument("scale parameter must be > 0"); } - } - - else if (prior->has_ngg_prior()) { + } else if (prior->has_ngg_prior()) { // Get hyperparameters: // for mu0 double mu00 = prior->ngg_prior().mean_prior().mean(); @@ -102,22 +70,39 @@ void NNIGHierarchy::initialize_hypers() { hypers->var_scaling = alpha00 / beta00; hypers->shape = alpha0; hypers->scale = a00 / b00; - } - - else { + } else { throw std::invalid_argument("Unrecognized hierarchy prior"); } } -void NNIGHierarchy::update_hypers( +double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { + auto &state = downcast_state(state_).uni_ls_state(); + double target = + stan::math::normal_lpdf(state.mean(), hypers->mean, + sqrt(state.var() / hypers->var_scaling)) + + stan::math::inv_gamma_lpdf(state.var(), hypers->shape, hypers->scale); + return target; +} + +State::UniLS NIGPriorModel::sample(ProtoHypersPtr hier_hypers) { + auto &rng = bayesmix::Rng::Instance().get(); + auto params = (hier_hypers) ? 
hier_hypers->nnig_state() + : get_hypers_proto()->nnig_state(); + + State::UniLS out; + out.var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); + out.mean = stan::math::normal_rng(params.mean(), + sqrt(out.var / params.var_scaling()), rng); + return out; +} + +void NIGPriorModel::update_hypers( const std::vector &states) { auto &rng = bayesmix::Rng::Instance().get(); if (prior->has_fixed_values()) { return; - } - - else if (prior->has_normal_mean_prior()) { + } else if (prior->has_normal_mean_prior()) { // Get hyperparameters double mu00 = prior->normal_mean_prior().mean_prior().mean(); double sig200 = prior->normal_mean_prior().mean_prior().var(); @@ -137,9 +122,7 @@ void NNIGHierarchy::update_hypers( double sig2_n = 1 / prec; // Update hyperparameters with posterior random sampling hypers->mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); - } - - else if (prior->has_ngg_prior()) { + } else if (prior->has_ngg_prior()) { // Get hyperparameters: // for mu0 double mu00 = prior->ngg_prior().mean_prior().mean(); @@ -173,78 +156,12 @@ void NNIGHierarchy::update_hypers( hypers->mean = stan::math::normal_rng(mu_n, sig_n, rng); hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); hypers->scale = stan::math::gamma_rng(a_n, b_n, rng); - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -NNIG::State NNIGHierarchy::draw(const NNIG::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - NNIG::State out; - out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - out.mean = stan::math::normal_rng(params.mean, - sqrt(state.var / params.var_scaling), rng); - return out; -} - -void NNIGHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, - const bool add) { - if (add) { - data_sum += datum(0); - data_sum_squares += datum(0) * datum(0); } else { - data_sum -= datum(0); - data_sum_squares -= datum(0) * datum(0); - } -} - -void NNIGHierarchy::clear_summary_statistics() 
{ - data_sum = 0; - data_sum_squares = 0; -} - -NNIG::Hyperparams NNIGHierarchy::compute_posterior_hypers() const { - // Initialize relevant variables - if (card == 0) { // no update possible - return *hypers; + throw std::invalid_argument("Unrecognized hierarchy prior"); } - // Compute posterior hyperparameters - NNIG::Hyperparams post_params; - double y_bar = data_sum / (1.0 * card); // sample mean - double ss = data_sum_squares - card * y_bar * y_bar; - post_params.mean = (hypers->var_scaling * hypers->mean + data_sum) / - (hypers->var_scaling + card); - post_params.var_scaling = hypers->var_scaling + card; - post_params.shape = hypers->shape + 0.5 * card; - post_params.scale = hypers->scale + 0.5 * ss + - 0.5 * hypers->var_scaling * card * - (y_bar - hypers->mean) * (y_bar - hypers->mean) / - (card + hypers->var_scaling); - return post_params; -} - -void NNIGHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.mean = statecast.uni_ls_state().mean(); - state.var = statecast.uni_ls_state().var(); - set_card(statecast.cardinality()); -} - -std::shared_ptr -NNIGHierarchy::get_state_proto() const { - bayesmix::UniLSState state_; - state_.set_mean(state.mean); - state_.set_var(state.var); - - auto out = std::make_shared(); - out->mutable_uni_ls_state()->CopyFrom(state_); - return out; } -void NNIGHierarchy::set_hypers_from_proto( +void NIGPriorModel::set_hypers_from_proto( const google::protobuf::Message &hypers_) { auto &hyperscast = downcast_hypers(hypers_).nnig_state(); hypers->mean = hyperscast.mean(); @@ -254,7 +171,7 @@ void NNIGHierarchy::set_hypers_from_proto( } std::shared_ptr -NNIGHierarchy::get_hypers_proto() const { +NIGPriorModel::get_hypers_proto() const { bayesmix::NIGDistribution hypers_; hypers_.set_mean(hypers->mean); hypers_.set_var_scaling(hypers->var_scaling); diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h new file mode 
100644 index 000000000..43a049ea8 --- /dev/null +++ b/src/hierarchies/priors/nig_prior_model.h @@ -0,0 +1,70 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_NIG_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_NIG_PRIOR_MODEL_H_ + +#include +#include +#include +#include + +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +/** + * A conjugate prior model for the univariate normal likelihood, that is + * + * \f[ + * \mu \mid \sigma^2 &\sim N(\mu_0, \sigma^2 / \lambda) \\ + * \sigma^2 &\sim InvGamma(a,b) + * \f] + * + * With several possibilies for hyper-priors on \f$ \mu \f$ and \f$ \sigma^2 + * \f$. We have considered a normal prior for \f$ mu0 \f$ and a + * Normal-Gamma-Gamma for \f$ (mu0, a, b) \f$ in addition to fixing prior + * hyperparameters + */ + +class NIGPriorModel + : public BasePriorModel { + public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + + NIGPriorModel() = default; + ~NIGPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + template + T lpdf_from_unconstrained( + const Eigen::Matrix &unconstrained_params) const { + Eigen::Matrix constrained_params = + State::uni_ls_to_constrained(unconstrained_params); + T log_det_jac = State::uni_ls_log_det_jac(constrained_params); + T mean = constrained_params(0); + T var = constrained_params(1); + T lpdf = stan::math::normal_lpdf(mean, hypers->mean, + sqrt(var / hypers->var_scaling)) + + stan::math::inv_gamma_lpdf(var, hypers->shape, hypers->scale); + + return lpdf + log_det_jac; + } + + State::UniLS sample(ProtoHypersPtr hier_hypers = nullptr) override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + std::shared_ptr get_hypers_proto() + const override; + + protected: + void initialize_hypers() override; +}; + +#endif // 
BAYESMIX_HIERARCHIES_PRIORS_NIG_PRIOR_MODEL_H_ diff --git a/src/hierarchies/nnw_hierarchy.cc b/src/hierarchies/priors/nw_prior_model.cc similarity index 56% rename from src/hierarchies/nnw_hierarchy.cc rename to src/hierarchies/priors/nw_prior_model.cc index 65038e946..46c9f6c77 100644 --- a/src/hierarchies/nnw_hierarchy.cc +++ b/src/hierarchies/priors/nw_prior_model.cc @@ -1,75 +1,10 @@ -#include "nnw_hierarchy.h" +#include "nw_prior_model.h" -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "matrix.pb.h" #include "src/utils/distributions.h" #include "src/utils/eigen_utils.h" #include "src/utils/proto_utils.h" -#include "src/utils/rng.h" - -double NNWHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { - return bayesmix::multi_normal_prec_lpdf(datum, state.mean, state.prec_chol, - state.prec_logdet); -} - -double NNWHierarchy::marg_lpdf(const NNW::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const { - NNW::Hyperparams pred_params = get_predictive_t_parameters(params); - Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); - double logdet = 2 * log(diag.array()).sum(); - - return bayesmix::multi_student_t_invscale_lpdf( - datum, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, - logdet); -} - -Eigen::VectorXd NNWHierarchy::like_lpdf_grid( - const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { - // Custom, optimized grid method - return bayesmix::multi_normal_prec_lpdf_grid( - data, state.mean, state.prec_chol, state.prec_logdet); -} - -Eigen::VectorXd NNWHierarchy::prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { - // Custom, optimized grid method - NNW::Hyperparams pred_params = get_predictive_t_parameters(*hypers); - Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); - double logdet = 2 * log(diag.array()).sum(); - - return 
bayesmix::multi_student_t_invscale_lpdf_grid( - data, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, - logdet); -} - -Eigen::VectorXd NNWHierarchy::conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { - // Custom, optimized grid method - NNW::Hyperparams pred_params = - get_predictive_t_parameters(compute_posterior_hypers()); - Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); - double logdet = 2 * log(diag.array()).sum(); - - return bayesmix::multi_student_t_invscale_lpdf_grid( - data, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, - logdet); -} -void NNWHierarchy::initialize_state() { - state.mean = hypers->mean; - write_prec_to_state( - hypers->var_scaling * Eigen::MatrixXd::Identity(dim, dim), &state); -} - -void NNWHierarchy::initialize_hypers() { +void NWPriorModel::initialize_hypers() { if (prior->has_fixed_values()) { // Set values hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); @@ -102,7 +37,7 @@ void NNWHierarchy::initialize_hypers() { bayesmix::to_eigen(prior->normal_mean_prior().scale()); double nu0 = prior->normal_mean_prior().deg_free(); // Check validity - unsigned int dim = mu00.size(); + dim = mu00.size(); if (sigma00.rows() != dim or tau0.rows() != dim) { throw std::invalid_argument( "Hyperparameters dimensions are not consistent"); @@ -175,9 +110,39 @@ void NNWHierarchy::initialize_hypers() { hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); } -void NNWHierarchy::update_hypers( +double NWPriorModel::lpdf(const google::protobuf::Message &state_) { + auto &state = downcast_state(state_).multi_ls_state(); + Eigen::VectorXd mean = bayesmix::to_eigen(state.mean()); + Eigen::MatrixXd prec = bayesmix::to_eigen(state.prec()); + double target = + stan::math::multi_normal_prec_lpdf(mean, hypers->mean, + prec * hypers->var_scaling) + + stan::math::wishart_lpdf(prec, hypers->deg_free, hypers->scale); + return target; +} + +State::MultiLS 
NWPriorModel::sample(ProtoHypersPtr hier_hypers) { + auto &rng = bayesmix::Rng::Instance().get(); + auto params = (hier_hypers) ? hier_hypers->nnw_state() + : get_hypers_proto()->nnw_state(); + Eigen::MatrixXd scale = bayesmix::to_eigen(params.scale()); + Eigen::VectorXd mean = bayesmix::to_eigen(params.mean()); + + Eigen::MatrixXd tau_new = + stan::math::wishart_rng(params.deg_free(), scale, rng); + + // Update state + State::MultiLS out; + out.mean = stan::math::multi_normal_prec_rng( + mean, tau_new * params.var_scaling(), rng); + write_prec_to_state(tau_new, &out); + return out; +}; + +void NWPriorModel::update_hypers( const std::vector &states) { auto &rng = bayesmix::Rng::Instance().get(); + if (prior->has_fixed_values()) { return; } @@ -253,121 +218,39 @@ void NNWHierarchy::update_hypers( } } -NNW::State NNWHierarchy::draw(const NNW::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - Eigen::MatrixXd tau_new = - stan::math::wishart_rng(params.deg_free, params.scale, rng); - // Update state - NNW::State out; - out.mean = stan::math::multi_normal_prec_rng( - params.mean, tau_new * params.var_scaling, rng); - write_prec_to_state(tau_new, &out); - return out; -} - -void NNWHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, - const bool add) { - if (add) { - data_sum += datum.transpose(); - data_sum_squares += datum.transpose() * datum; - } else { - data_sum -= datum.transpose(); - data_sum_squares -= datum.transpose() * datum; - } -} - -void NNWHierarchy::clear_summary_statistics() { - data_sum = Eigen::VectorXd::Zero(dim); - data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); -} - -NNW::Hyperparams NNWHierarchy::compute_posterior_hypers() const { - if (card == 0) { // no update possible - return *hypers; - } - // Compute posterior hyperparameters - NNW::Hyperparams post_params; - post_params.var_scaling = hypers->var_scaling + card; - post_params.deg_free = hypers->deg_free + card; - Eigen::VectorXd mubar = data_sum.array() / 
card; // sample mean - post_params.mean = (hypers->var_scaling * hypers->mean + card * mubar) / - (hypers->var_scaling + card); - // Compute tau_n - Eigen::MatrixXd tau_temp = - data_sum_squares - card * mubar * mubar.transpose(); - tau_temp += (card * hypers->var_scaling / (card + hypers->var_scaling)) * - (mubar - hypers->mean) * (mubar - hypers->mean).transpose(); - post_params.scale_inv = tau_temp + hypers->scale_inv; - post_params.scale = stan::math::inverse_spd(post_params.scale_inv); - post_params.scale_chol = - Eigen::LLT(post_params.scale).matrixU(); - return post_params; -} - -void NNWHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.mean = to_eigen(statecast.multi_ls_state().mean()); - state.prec = to_eigen(statecast.multi_ls_state().prec()); - state.prec_chol = to_eigen(statecast.multi_ls_state().prec_chol()); - Eigen::VectorXd diag = state.prec_chol.diagonal(); - state.prec_logdet = 2 * log(diag.array()).sum(); - set_card(statecast.cardinality()); -} - -std::shared_ptr -NNWHierarchy::get_state_proto() const { - bayesmix::MultiLSState state_; - bayesmix::to_proto(state.mean, state_.mutable_mean()); - bayesmix::to_proto(state.prec, state_.mutable_prec()); - bayesmix::to_proto(state.prec_chol, state_.mutable_prec_chol()); - - auto out = std::make_shared(); - out->mutable_multi_ls_state()->CopyFrom(state_); - return out; -} - -void NNWHierarchy::set_hypers_from_proto( +void NWPriorModel::set_hypers_from_proto( const google::protobuf::Message &hypers_) { auto &hyperscast = downcast_hypers(hypers_).nnw_state(); - hypers->mean = to_eigen(hyperscast.mean()); + hypers->mean = bayesmix::to_eigen(hyperscast.mean()); hypers->var_scaling = hyperscast.var_scaling(); hypers->deg_free = hyperscast.deg_free(); - hypers->scale = to_eigen(hyperscast.scale()); + hypers->scale = bayesmix::to_eigen(hyperscast.scale()); + hypers->scale_inv = stan::math::inverse_spd(hypers->scale); + 
hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); } std::shared_ptr -NNWHierarchy::get_hypers_proto() const { - bayesmix::NWDistribution hypers_; - bayesmix::to_proto(hypers->mean, hypers_.mutable_mean()); - hypers_.set_var_scaling(hypers->var_scaling); - hypers_.set_deg_free(hypers->deg_free); - bayesmix::to_proto(hypers->scale, hypers_.mutable_scale()); - +NWPriorModel::get_hypers_proto() const { + // Translate to proto + bayesmix::Vector mean_proto; + bayesmix::Matrix scale_proto; + bayesmix::to_proto(hypers->mean, &mean_proto); + bayesmix::to_proto(hypers->scale, &scale_proto); + + // Make output state and return auto out = std::make_shared(); - out->mutable_nnw_state()->CopyFrom(hypers_); + out->mutable_nnw_state()->mutable_mean()->CopyFrom(mean_proto); + out->mutable_nnw_state()->set_var_scaling(hypers->var_scaling); + out->mutable_nnw_state()->set_deg_free(hypers->deg_free); + out->mutable_nnw_state()->mutable_scale()->CopyFrom(scale_proto); return out; } -void NNWHierarchy::write_prec_to_state(const Eigen::MatrixXd &prec_, - NNW::State *out) { +void NWPriorModel::write_prec_to_state(const Eigen::MatrixXd &prec_, + State::MultiLS *out) { out->prec = prec_; // Update prec utilities out->prec_chol = Eigen::LLT(prec_).matrixU(); Eigen::VectorXd diag = out->prec_chol.diagonal(); out->prec_logdet = 2 * log(diag.array()).sum(); } - -NNW::Hyperparams NNWHierarchy::get_predictive_t_parameters( - const NNW::Hyperparams ¶ms) const { - // Compute dof and scale of marginal distribution - double nu_n = params.deg_free - dim + 1; - double coeff = (params.var_scaling + 1) / (params.var_scaling * nu_n); - Eigen::MatrixXd scale_chol_n = params.scale_chol / std::sqrt(coeff); - - NNW::Hyperparams out; - out.mean = params.mean; - out.deg_free = nu_n; - out.scale_chol = scale_chol_n; - return out; -} diff --git a/src/hierarchies/priors/nw_prior_model.h b/src/hierarchies/priors/nw_prior_model.h new file mode 100644 index 000000000..1a6394bd3 --- /dev/null +++ 
b/src/hierarchies/priors/nw_prior_model.h @@ -0,0 +1,57 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_NW_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_NW_PRIOR_MODEL_H_ + +#include +#include +#include +#include + +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +/** + * A conjugate prior model for the multivariate normal likelihood, that is + * + * \f[ + * \bm{\mu} \mid \Sigma &\sim N_p(\bm{\mu}_0, (\Sigma \lambda)^{-1}) \\ + * \Sigma & \sim Wishart(\nu_0, \Psi_0) + * \f] + * + * With some options for hyper-priors on \f$ \bm{\mu} \f$ and \f$ \Sigma \f$. + * We have considered a normal prior for \f$ \bm{\mu}_0 \f$ in addition to + * fixing prior hyperparameters + */ + +class NWPriorModel + : public BasePriorModel { + public: + NWPriorModel() = default; + ~NWPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + State::MultiLS sample(ProtoHypersPtr hier_hypers = nullptr) override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + void write_prec_to_state(const Eigen::MatrixXd &prec_, State::MultiLS *out); + + unsigned int get_dim() const { return dim; }; + + std::shared_ptr get_hypers_proto() + const override; + + protected: + void initialize_hypers() override; + + unsigned int dim; +}; + +#endif // BAYESMIX_HIERARCHIES_PRIORS_NW_PRIOR_MODEL_H_ diff --git a/src/hierarchies/priors/nxig_prior_model.cc b/src/hierarchies/priors/nxig_prior_model.cc new file mode 100644 index 000000000..0b7b0cbea --- /dev/null +++ b/src/hierarchies/priors/nxig_prior_model.cc @@ -0,0 +1,73 @@ +#include "nxig_prior_model.h" + +void NxIGPriorModel::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->mean = prior->fixed_values().mean(); + hypers->var = prior->fixed_values().var(); + hypers->shape = prior->fixed_values().shape(); + hypers->scale = 
prior->fixed_values().scale(); + // Check validity + if (hypers->var <= 0) { + throw std::invalid_argument("Variance parameter must be > 0"); + } + if (hypers->shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers->scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +double NxIGPriorModel::lpdf(const google::protobuf::Message &state_) { + auto &state = downcast_state(state_).uni_ls_state(); + double target = + stan::math::normal_lpdf(state.mean(), hypers->mean, sqrt(hypers->var)) + + stan::math::inv_gamma_lpdf(state.var(), hypers->shape, hypers->scale); + return target; +} + +State::UniLS NxIGPriorModel::sample(ProtoHypersPtr hier_hypers) { + auto &rng = bayesmix::Rng::Instance().get(); + auto params = (hier_hypers) ? hier_hypers->nnxig_state() + : get_hypers_proto()->nnxig_state(); + + State::UniLS out; + out.mean = stan::math::normal_rng(params.mean(), sqrt(params.var()), rng); + out.var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); + return out; +}; + +void NxIGPriorModel::update_hypers( + const std::vector &states) { + if (prior->has_fixed_values()) { + return; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void NxIGPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).nnxig_state(); + hypers->mean = hyperscast.mean(); + hypers->var = hyperscast.var(); + hypers->scale = hyperscast.scale(); + hypers->shape = hyperscast.shape(); +} + +std::shared_ptr +NxIGPriorModel::get_hypers_proto() const { + bayesmix::NxIGDistribution hypers_; + hypers_.set_mean(hypers->mean); + hypers_.set_var(hypers->var); + hypers_.set_shape(hypers->shape); + hypers_.set_scale(hypers->scale); + + auto out = std::make_shared(); + out->mutable_nnxig_state()->CopyFrom(hypers_); + return out; +} diff --git 
a/src/hierarchies/priors/nxig_prior_model.h b/src/hierarchies/priors/nxig_prior_model.h new file mode 100644 index 000000000..b77cfe16d --- /dev/null +++ b/src/hierarchies/priors/nxig_prior_model.h @@ -0,0 +1,50 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_NXIG_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_NXIG_PRIOR_MODEL_H_ + +#include +#include +#include +#include + +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +/** + * A semi-conjugate prior model for the univariate normal likelihood, that is + * + * \f[ + * \mu & \sim N(\mu_0, \eta^2) \\ + * \sigma^2 & \sim InvGamma(a,b) + * \f] + */ + +class NxIGPriorModel + : public BasePriorModel { + public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + + NxIGPriorModel() = default; + ~NxIGPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + State::UniLS sample(ProtoHypersPtr hier_hypers = nullptr) override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + std::shared_ptr get_hypers_proto() + const override; + + protected: + void initialize_hypers() override; +}; + +#endif // BAYESMIX_HIERARCHIES_PRIORS_NXIG_PRIOR_MODEL_H_ diff --git a/src/hierarchies/priors/prior_model_internal.h b/src/hierarchies/priors/prior_model_internal.h new file mode 100644 index 000000000..e1de1f5ff --- /dev/null +++ b/src/hierarchies/priors/prior_model_internal.h @@ -0,0 +1,32 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_PRIOR_MODEL_INTERNAL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_PRIOR_MODEL_INTERNAL_H_ + +//! These functions exploit SFINAE to manage exception handling in all methods +//! required only if end user wants to rely on Metropolis-like updaters. SFINAE +//! (Substitution Failure Is Not An Error) is a C++ rule that applies during +//! 
overload resolution of function templates: When substituting the explicitly +//! specified or deduced type for the template parameter fails, the +//! specialization is discarded from the overload set instead of causing a +//! compile error. This feature is used in template metaprogramming. + +namespace internal { + +template +auto lpdf_from_unconstrained( + const Prior &prior, + Eigen::Matrix unconstrained_params, int) + -> decltype(prior.template lpdf_from_unconstrained( + unconstrained_params)) { + return prior.template lpdf_from_unconstrained(unconstrained_params); +} + +template +auto lpdf_from_unconstrained( + const Prior &prior, + Eigen::Matrix unconstrained_params, double) -> T { + throw(std::runtime_error("lpdf_from_unconstrained() not yet implemented")); +} + +} // namespace internal + +#endif // BAYESMIX_HIERARCHIES_PRIORS_PRIOR_MODEL_INTERNAL_H_ diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt new file mode 100644 index 000000000..dbc5e7e9a --- /dev/null +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -0,0 +1,18 @@ +target_sources(bayesmix PUBLIC + abstract_updater.h + semi_conjugate_updater.h + nnig_updater.h + nnig_updater.cc + nnxig_updater.h + nnxig_updater.cc + nnw_updater.h + nnw_updater.cc + mnig_updater.h + mnig_updater.cc + fa_updater.h + fa_updater.cc + metropolis_updater.h + mala_updater.h + random_walk_updater.h + target_lpdf_unconstrained.h +) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h new file mode 100644 index 000000000..eaa19ca58 --- /dev/null +++ b/src/hierarchies/updaters/abstract_updater.h @@ -0,0 +1,61 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ + +#include "src/hierarchies/likelihoods/abstract_likelihood.h" +#include "src/hierarchies/priors/abstract_prior_model.h" + +//! Base class for an Updater + +//! 
An updater is a class able to sample from the full conditional +//! distribution of an `Hierarchy`, coming from the product of a +//! `Likelihood` and a `Prior`, possibly using a Metropolis-Hastings +//! algorithm. +class AbstractUpdater { + public: + // Type aliases + using ProtoHypersPtr = + std::shared_ptr; + using ProtoHypers = ProtoHypersPtr::element_type; + + //! Default destructor + virtual ~AbstractUpdater() = default; + + //! Returns whether the current updater is for a (semi)conjugate model or not + virtual bool is_conjugate() const { return false; }; + + //! Sampling from the full conditional, given the likelihood and the prior + //! model that constitutes the hierarchy + //! @param like The likelihood of the hierarchy + //! @param prior The prior model of the hierarchy + //! @param update_params Save posterior hyperparameters after draw? + virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, + bool update_params) = 0; + + //! Computes the posterior hyperparameters required for the sampling in case + //! of conjugate hierarchies + virtual ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) { + if (!is_conjugate()) { + throw( + std::runtime_error("Cannot call compute_posterior_hypers() from a " + "non-(semi)conjugate updater")); + } else { + throw(std::runtime_error( + "compute_posterior_hypers() not implemented for this updater")); + } + } + + //! 
Stores the posterior hyperparameters in an appropriate container + virtual void save_posterior_hypers(ProtoHypersPtr post_hypers_) { + if (!is_conjugate()) { + throw( + std::runtime_error("Cannot call save_posterior_hypers() from a " + "non-(semi)conjugate updater")); + } else { + throw(std::runtime_error( + "save_posterior_hypers() not implemented for this updater")); + } + } +}; + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ diff --git a/src/hierarchies/updaters/fa_updater.cc b/src/hierarchies/updaters/fa_updater.cc new file mode 100644 index 000000000..24d3408a8 --- /dev/null +++ b/src/hierarchies/updaters/fa_updater.cc @@ -0,0 +1,125 @@ +#include "fa_updater.h" + +#include "src/utils/distributions.h" + +void FAUpdater::draw(AbstractLikelihood& like, AbstractPriorModel& prior, + bool update_params) { + // Likelihood and PriorModel downcast + auto& likecast = static_cast(like); + auto& priorcast = static_cast(prior); + // Sample from the full conditional of the fa hierarchy + bool set_card = true, use_post_hypers = true; + if (likecast.get_card() == 0) { + likecast.set_state(priorcast.sample(), !set_card); + } else { + // Get state and hypers + State::FA new_state = likecast.get_state(); + Hyperparams::FA hypers = priorcast.get_hypers(); + // Gibbs update + sample_eta(new_state, hypers, likecast); + sample_mu(new_state, hypers, likecast); + sample_psi(new_state, hypers, likecast); + sample_lambda(new_state, hypers, likecast); + likecast.set_state(new_state, !set_card); + } +} + +void FAUpdater::sample_eta(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like) { + // Random Seed + auto& rng = bayesmix::Rng::Instance().get(); + // Get required information + auto dataset_ptr = like.get_dataset(); + auto cluster_data_idx = like.get_data_idx(); + unsigned int card = like.get_card(); + // eta update + state.eta = Eigen::MatrixXd::Zero(card, hypers.q); + auto sigma_eta_inv_llt = + (Eigen::MatrixXd::Identity(hypers.q, hypers.q) + + 
state.lambda.transpose() * state.psi_inverse * state.lambda) + .llt(); + Eigen::MatrixXd temp_product( + sigma_eta_inv_llt.solve(state.lambda.transpose() * state.psi_inverse)); + auto iterator = cluster_data_idx.begin(); + for (size_t i = 0; i < card; i++, iterator++) { + Eigen::VectorXd tempvector(dataset_ptr->row( + *iterator)); // TODO use slicing when Eigen is updated to v3.4 + state.eta.row(i) = (bayesmix::multi_normal_prec_chol_rng( + temp_product * (tempvector - state.mu), sigma_eta_inv_llt, rng)); + } +} + +void FAUpdater::sample_mu(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like) { + // Random seed + auto& rng = bayesmix::Rng::Instance().get(); + // Get required information + Eigen::VectorXd data_sum = like.get_data_sum(); + unsigned int card = like.get_card(); + // mu update + Eigen::DiagonalMatrix sigma_mu; + sigma_mu.diagonal() = + (card * state.psi_inverse.diagonal().array() + hypers.phi) + .cwiseInverse(); + Eigen::VectorXd sum = (state.eta.colwise().sum()); + Eigen::VectorXd mumean = + sigma_mu * (hypers.phi * hypers.mutilde + + state.psi_inverse * (data_sum - state.lambda * sum)); + state.mu = bayesmix::multi_normal_diag_rng(mumean, sigma_mu, rng); +} + +void FAUpdater::sample_lambda(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like) { + // Random seed + auto& rng = bayesmix::Rng::Instance().get(); + // Getting required information + unsigned int dim = like.get_dim(); + unsigned int card = like.get_card(); + auto dataset_ptr = like.get_dataset(); + auto cluster_data_idx = like.get_data_idx(); + // lambda update + Eigen::MatrixXd temp_etateta(state.eta.transpose() * state.eta); + for (size_t j = 0; j < dim; j++) { + auto sigma_lambda_inv_llt = + (Eigen::MatrixXd::Identity(hypers.q, hypers.q) + + temp_etateta / state.psi[j]) + .llt(); + Eigen::VectorXd tempsum(card); + const Eigen::VectorXd& data_col = dataset_ptr->col(j); + auto iterator = cluster_data_idx.begin(); + for (size_t i = 0; i < card; 
i++, iterator++) { + tempsum[i] = data_col( + *iterator); // TODO use slicing when Eigen is updated to v3.4 + } + tempsum = tempsum.array() - state.mu[j]; + tempsum = tempsum.array() / state.psi[j]; + state.lambda.row(j) = bayesmix::multi_normal_prec_chol_rng( + sigma_lambda_inv_llt.solve(state.eta.transpose() * tempsum), + sigma_lambda_inv_llt, rng); + } +} + +void FAUpdater::sample_psi(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like) { + // Random seed + auto& rng = bayesmix::Rng::Instance().get(); + // Getting required information + auto dataset_ptr = like.get_dataset(); + auto cluster_data_idx = like.get_data_idx(); + unsigned int dim = like.get_dim(); + unsigned int card = like.get_card(); + // psi update + for (size_t j = 0; j < dim; j++) { + double sum = 0; + auto iterator = cluster_data_idx.begin(); + for (size_t i = 0; i < card; i++, iterator++) { + sum += std::pow( + ((*dataset_ptr)(*iterator, j) - + state.mu[j] - // TODO use slicing when Eigen is updated to v3.4 + state.lambda.row(j).dot(state.eta.row(i))), + 2); + } + state.psi[j] = stan::math::inv_gamma_rng(hypers.alpha0 + card / 2, + hypers.beta[j] + sum / 2, rng); + } +} diff --git a/src/hierarchies/updaters/fa_updater.h b/src/hierarchies/updaters/fa_updater.h new file mode 100644 index 000000000..b02bdbfca --- /dev/null +++ b/src/hierarchies/updaters/fa_updater.h @@ -0,0 +1,33 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_FA_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_FA_UPDATER_H_ + +#include "abstract_updater.h" +#include "src/hierarchies/likelihoods/fa_likelihood.h" +#include "src/hierarchies/likelihoods/states/includes.h" +#include "src/hierarchies/priors/fa_prior_model.h" +#include "src/hierarchies/priors/hyperparams.h" +#include "src/utils/proto_utils.h" + +//! Updater specific for the `FAHierachy`. +//! See Bhattacharya, Anirban, and David B. Dunson. +//! "Sparse Bayesian infinite factor models." Biometrika (2011): 291-306. +//! 
for further details +class FAUpdater : public AbstractUpdater { + public: + FAUpdater() = default; + ~FAUpdater() = default; + void draw(AbstractLikelihood& like, AbstractPriorModel& prior, + bool update_params) override; + + protected: + void sample_eta(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like); + void sample_mu(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like); + void sample_lambda(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like); + void sample_psi(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like); +}; + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_FA_UPDATER_H_ diff --git a/src/hierarchies/updaters/mala_updater.h b/src/hierarchies/updaters/mala_updater.h new file mode 100644 index 000000000..5f012b558 --- /dev/null +++ b/src/hierarchies/updaters/mala_updater.h @@ -0,0 +1,92 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_MALA_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_MALA_UPDATER_H_ + +#include + +#include "metropolis_updater.h" + +/** + * Metropolis Adjusted Langevin Algorithm. + * + * This class requires that the Hierarchy's state implements + * the `get_unconstrained()`, `set_from_unconstrained()` and + * `log_det_jac()` functions. + * + * Given the current value of the unconstrained parameters \f$ x \f$, a new + * value is proposed from + * + * \f[ + * x_{new} \sim N(x + step\_size \cdot \text{grad}(full\_cond)(x), \sqrt{2 + * step\_size} \cdot I) \f] + * + * and then either accepted (in which case the hierarchy's state is + * set to \f$ x_{new} \f$) or rejected. + */ + +class MalaUpdater : public MetropolisUpdater { + protected: + double step_size; + + public: + MalaUpdater() = default; + ~MalaUpdater() = default; + + MalaUpdater(double step_size) : step_size(step_size) {} + + //! Samples from the proposal distribution + //! @param curr_state the current state (unconstrained parametrization) + //! @param like instance of likelihood + //! 
@param prior instance of prior + //! @param target_lpdf either double or stan::math::var. Needed for + //! stan's automatic differentiation. It will be + //! filled with the lpdf at the 'curr_state' + Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, + AbstractLikelihood &like, + AbstractPriorModel &prior, + target_lpdf_unconstrained &target_lpdf) { + Eigen::VectorXd noise(curr_state.size()); + auto &rng = bayesmix::Rng::Instance().get(); + double noise_scale = std::sqrt(2 * step_size); + for (int i = 0; i < curr_state.size(); i++) { + noise(i) = stan::math::normal_rng(0, noise_scale, rng); + } + Eigen::VectorXd grad; + double tmp; + stan::math::gradient(target_lpdf, curr_state, tmp, grad); + return curr_state + step_size * grad + noise; + } + + //! Evaluates the log probability density function of the proposal + //! @param prop_state the proposed state (at which to evaluate the lpdf) + //! @param curr_state the current state (unconstrained parametrization) + //! @param like instance of likelihood + //! @param prior instance of prior + //! @param target_lpdf either double or stan::math::var. Needed for + //! stan's automatic differentiation. It will be + //! filled with the lpdf at 'curr_state' + double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, + AbstractLikelihood &like, AbstractPriorModel &prior, + target_lpdf_unconstrained &target_lpdf) { + double out = 0; + Eigen::VectorXd grad; + double tmp; + stan::math::gradient(target_lpdf, curr_state, tmp, grad); + Eigen::VectorXd mean = curr_state + step_size * grad; + + double noise_scale = std::sqrt(2 * step_size); + + for (int i = 0; i < prop_state.size(); i++) { + out += stan::math::normal_lpdf(prop_state(i), mean(i), noise_scale); + } + return out; + } + + //! 
Returns a shared_ptr to a new instance of `this` + std::shared_ptr clone() const { + auto out = + std::make_shared(static_cast(*this)); + return out; + } +}; + +#endif diff --git a/src/hierarchies/updaters/metropolis_updater.h b/src/hierarchies/updaters/metropolis_updater.h new file mode 100644 index 000000000..e4bb93a42 --- /dev/null +++ b/src/hierarchies/updaters/metropolis_updater.h @@ -0,0 +1,61 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_METROPOLIS_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_METROPOLIS_UPDATER_H_ + +#include "abstract_updater.h" +#include "target_lpdf_unconstrained.h" + +//! Base class for updaters using a Metropolis-Hastings algorithm +//! +//! This class serves as the base for a CRTP. +//! Children of this class should implement the methods +//! template +//! Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, +//! AbstractLikelihood &like, +//! AbstractPriorModel &prior, F +//! &target_lpdf) +//! and +//! template +//! double proposal_lpdf(Eigen::VectorXd prop_state, +//! Eigen::VectorXd curr_state, +//! AbstractLikelihood &like, +//! AbstractPriorModel &prior, +//! F &target_lpdf) +//! where the template parameter is needed to allow the use of stan's +//! automatic differentiation if the gradient of the full conditional is +//! required. +template +class MetropolisUpdater : public AbstractUpdater { + public: + //! Samples from the full conditional distribution using a + //! Metropolis-Hastings step + void draw(AbstractLikelihood &like, AbstractPriorModel &prior, + bool update_params) override { + if (update_params) { + throw std::runtime_error( + "'update_params' can be True only when using instances of" + " 'SemiConjugateUpdater'. 
This is likely caused by" + " using a nonconjugate hierarchy (or a nonconjugate updater)" + " in a marginal algorithm such as 'Neal3'."); + } + + target_lpdf_unconstrained target_lpdf(&like, &prior); + Eigen::VectorXd curr_state = like.get_unconstrained_state(); + Eigen::VectorXd prop_state = + static_cast(this)->sample_proposal( + curr_state, like, prior, target_lpdf); + + double log_arate = like.cluster_lpdf_from_unconstrained(prop_state) - + like.cluster_lpdf_from_unconstrained(curr_state) + + static_cast(this)->proposal_lpdf( + curr_state, prop_state, like, prior, target_lpdf) - + static_cast(this)->proposal_lpdf( + prop_state, curr_state, like, prior, target_lpdf); + + auto &rng = bayesmix::Rng::Instance().get(); + if (std::log(stan::math::uniform_rng(0, 1, rng)) < log_arate) { + like.set_state_from_unconstrained(prop_state); + } + } +}; + +#endif diff --git a/src/hierarchies/updaters/mnig_updater.cc b/src/hierarchies/updaters/mnig_updater.cc new file mode 100644 index 000000000..beb822c1a --- /dev/null +++ b/src/hierarchies/updaters/mnig_updater.cc @@ -0,0 +1,44 @@ +#include "mnig_updater.h" + +AbstractUpdater::ProtoHypersPtr MNIGUpdater::compute_posterior_hypers( + AbstractLikelihood& like, AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Getting required quantities from likelihood and prior + int card = likecast.get_card(); + unsigned int dim = likecast.get_dim(); + double data_sum_squares = likecast.get_data_sum_squares(); + Eigen::MatrixXd covar_sum_squares = likecast.get_covar_sum_squares(); + Eigen::MatrixXd mixed_prod = likecast.get_mixed_prod(); + auto hypers = priorcast.get_hypers(); + + // No update possible + if (card == 0) { + return priorcast.get_hypers_proto(); + } + + // Compute posterior hyperparameters + Eigen::VectorXd mean; + Eigen::MatrixXd var_scaling, var_scaling_inv; + double shape, scale; + + var_scaling = covar_sum_squares + 
hypers.var_scaling; + auto llt = var_scaling.llt(); + mean = llt.solve(mixed_prod + hypers.var_scaling * hypers.mean); + shape = hypers.shape + 0.5 * card; + scale = hypers.scale + + 0.5 * (data_sum_squares + + hypers.mean.transpose() * hypers.var_scaling * hypers.mean - + mean.transpose() * var_scaling * mean); + + // Proto conversion + ProtoHypers out; + bayesmix::to_proto(mean, out.mutable_lin_reg_uni_state()->mutable_mean()); + bayesmix::to_proto(var_scaling, + out.mutable_lin_reg_uni_state()->mutable_var_scaling()); + out.mutable_lin_reg_uni_state()->set_shape(shape); + out.mutable_lin_reg_uni_state()->set_scale(scale); + return std::make_shared(out); +} diff --git a/src/hierarchies/updaters/mnig_updater.h b/src/hierarchies/updaters/mnig_updater.h new file mode 100644 index 000000000..ec7a8c65f --- /dev/null +++ b/src/hierarchies/updaters/mnig_updater.h @@ -0,0 +1,36 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_MNIG_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_MNIG_UPDATER_H_ + +#include "semi_conjugate_updater.h" +#include "src/hierarchies/likelihoods/uni_lin_reg_likelihood.h" +#include "src/hierarchies/priors/mnig_prior_model.h" + +/** + * Updater specific for the `UniLinRegLikelihood` used in combination + * with `MNIGPriorModel`, that is the model + * + * \f[ + * y_i \mid \bm{\beta}, \sigma^2 &\stackrel{\small\mathrm{iid}}{\sim} + * N(\bm{\beta}^T\bm{x}_i, \sigma^2) \\ + * \bm{\beta} \mid \sigma^2 &\sim N_p(\mu_{0}, \sigma^2 \mathbf{V}^{-1}) \\ + * \sigma^2 &\sim InvGamma(a, b) + * \f] + * + * It exploits the conjugacy of the model to sample the full conditional of + * \f$ (\bm{\beta}, \sigma^2) \f$ by calling `MNIGPriorModel::sample` with + * updated parameters + */ + +class MNIGUpdater + : public SemiConjugateUpdater { + public: + MNIGUpdater() = default; + ~MNIGUpdater() = default; + + bool is_conjugate() const override { return true; }; + + ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; +}; + 
+#endif // BAYESMIX_HIERARCHIES_UPDATERS_MNIG_UPDATER_H_ diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc new file mode 100644 index 000000000..c73ae39c1 --- /dev/null +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -0,0 +1,42 @@ +#include "nnig_updater.h" + +#include "src/hierarchies/likelihoods/states/includes.h" +#include "src/hierarchies/priors/hyperparams.h" + +AbstractUpdater::ProtoHypersPtr NNIGUpdater::compute_posterior_hypers( + AbstractLikelihood& like, AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Getting required quantities from likelihood and prior + int card = likecast.get_card(); + double data_sum = likecast.get_data_sum(); + double data_sum_squares = likecast.get_data_sum_squares(); + auto hypers = priorcast.get_hypers(); + + // No update possible + if (card == 0) { + return priorcast.get_hypers_proto(); + } + + // Compute posterior hyperparameters + double mean, var_scaling, shape, scale; + double y_bar = data_sum / (1.0 * card); // sample mean + double ss = data_sum_squares - card * y_bar * y_bar; + mean = (hypers.var_scaling * hypers.mean + data_sum) / + (hypers.var_scaling + card); + var_scaling = hypers.var_scaling + card; + shape = hypers.shape + 0.5 * card; + scale = hypers.scale + 0.5 * ss + + 0.5 * hypers.var_scaling * card * (y_bar - hypers.mean) * + (y_bar - hypers.mean) / (card + hypers.var_scaling); + + // Proto conversion + ProtoHypers out; + out.mutable_nnig_state()->set_mean(mean); + out.mutable_nnig_state()->set_var_scaling(var_scaling); + out.mutable_nnig_state()->set_shape(shape); + out.mutable_nnig_state()->set_scale(scale); + return std::make_shared(out); +} diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h new file mode 100644 index 000000000..5866735ba --- /dev/null +++ b/src/hierarchies/updaters/nnig_updater.h @@ -0,0 
+1,36 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ + +#include "semi_conjugate_updater.h" +#include "src/hierarchies/likelihoods/uni_norm_likelihood.h" +#include "src/hierarchies/priors/nig_prior_model.h" + +/** + * Updater specific for the `UniNormLikelihood` used in combination + * with `NIGPriorModel`, that is the model + * + * \f[ + * y_i \mid \mu, \sigma^2 &\stackrel{\small\mathrm{iid}}{\sim} N(\mu, + * \sigma^2) \\ + * \mu \mid \sigma^2 &\sim N(\mu_0, \sigma^2 / \lambda) \\ + * \sigma^2 &\sim InvGamma(a, b) + * \f] + * + * It exploits the conjugacy of the model to sample the full conditional of + * \f$ (\mu, \sigma^2) \f$ by calling `NIGPriorModel::sample` with updated + * parameters + */ + +class NNIGUpdater + : public SemiConjugateUpdater { + public: + NNIGUpdater() = default; + ~NNIGUpdater() = default; + + bool is_conjugate() const override { return true; }; + + ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; +}; + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ diff --git a/src/hierarchies/updaters/nnw_updater.cc b/src/hierarchies/updaters/nnw_updater.cc new file mode 100644 index 000000000..f265f84f4 --- /dev/null +++ b/src/hierarchies/updaters/nnw_updater.cc @@ -0,0 +1,49 @@ +#include "nnw_updater.h" + +#include "algorithm_state.pb.h" +#include "src/hierarchies/likelihoods/states/includes.h" +#include "src/hierarchies/priors/hyperparams.h" +#include "src/utils/proto_utils.h" + +AbstractUpdater::ProtoHypersPtr NNWUpdater::compute_posterior_hypers( + AbstractLikelihood& like, AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Getting required quantities from likelihood and prior + int card = likecast.get_card(); + Eigen::VectorXd data_sum = likecast.get_data_sum(); + Eigen::MatrixXd data_sum_squares = 
likecast.get_data_sum_squares(); + auto hypers = priorcast.get_hypers(); + + // No update possible + if (card == 0) { + return prior.get_hypers_proto(); + } + + // Compute posterior hyperparameters + Eigen::VectorXd mean; + double var_scaling, deg_free; + Eigen::MatrixXd scale, scale_inv, scale_chol; + var_scaling = hypers.var_scaling + card; + deg_free = hypers.deg_free + card; + Eigen::VectorXd mubar = data_sum.array() / card; // sample mean + mean = (hypers.var_scaling * hypers.mean + card * mubar) / + (hypers.var_scaling + card); + // Compute tau_n + Eigen::MatrixXd tau_temp = + data_sum_squares - card * mubar * mubar.transpose(); + tau_temp += (card * hypers.var_scaling / (card + hypers.var_scaling)) * + (mubar - hypers.mean) * (mubar - hypers.mean).transpose(); + scale_inv = tau_temp + hypers.scale_inv; + scale = stan::math::inverse_spd(scale_inv); + + // Proto conversion + ProtoHypers out; + bayesmix::to_proto(mean, out.mutable_nnw_state()->mutable_mean()); + out.mutable_nnw_state()->set_var_scaling(var_scaling); + out.mutable_nnw_state()->set_deg_free(deg_free); + bayesmix::to_proto(scale, out.mutable_nnw_state()->mutable_scale()); + return std::make_shared(out); +} diff --git a/src/hierarchies/updaters/nnw_updater.h b/src/hierarchies/updaters/nnw_updater.h new file mode 100644 index 000000000..b7877274d --- /dev/null +++ b/src/hierarchies/updaters/nnw_updater.h @@ -0,0 +1,36 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_NNW_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_NNW_UPDATER_H_ + +#include "semi_conjugate_updater.h" +#include "src/hierarchies/likelihoods/multi_norm_likelihood.h" +#include "src/hierarchies/priors/nw_prior_model.h" + +/** + * Updater specific for the `MultiNormLikelihood` used in combination + * with `NWPriorModel`, that is the model + * + * \f[ + * y_i \mid \bm{\mu}, \Sigma &\stackrel{\small\mathrm{iid}}{\sim} + * N_d(\bm{mu}, \Sigma) \\ + * \bm{\mu} \mid \Sigma &\sim N_d(\bm{\mu}_0, \Sigma / \lambda) \\ + * \Sigma^{-1} &\sim 
Wishart(\nu, \Psi) + * \f] + * + * It exploits the conjugacy of the model to sample the full conditional of + * \f$ (\bm{\mu}, \Sigma) \f$ by calling `NWPriorModel::sample` with updated + * parameters. + */ + +class NNWUpdater + : public SemiConjugateUpdater { + public: + NNWUpdater() = default; + ~NNWUpdater() = default; + + bool is_conjugate() const override { return true; }; + + ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; +}; + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_NNW_UPDATER_H_ diff --git a/src/hierarchies/updaters/nnxig_updater.cc b/src/hierarchies/updaters/nnxig_updater.cc new file mode 100644 index 000000000..84f91c73c --- /dev/null +++ b/src/hierarchies/updaters/nnxig_updater.cc @@ -0,0 +1,41 @@ +#include "nnxig_updater.h" + +#include "src/hierarchies/likelihoods/states/includes.h" +#include "src/hierarchies/priors/hyperparams.h" + +AbstractUpdater::ProtoHypersPtr NNxIGUpdater::compute_posterior_hypers( + AbstractLikelihood& like, AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Getting required quantities from likelihood and prior + auto state = likecast.get_state(); + int card = likecast.get_card(); + double data_sum = likecast.get_data_sum(); + double data_sum_squares = likecast.get_data_sum_squares(); + auto hypers = priorcast.get_hypers(); + + // No update possible + if (card == 0) { + return priorcast.get_hypers_proto(); + } + + // Compute posterior hyperparameters + double mean, var, shape, scale; + double var_y = data_sum_squares - 2 * state.mean * data_sum + + card * state.mean * state.mean; + mean = (hypers.var * data_sum + state.var * hypers.mean) / + (card * hypers.var + state.var); + var = (state.var * hypers.var) / (card * hypers.var + state.var); + shape = hypers.shape + 0.5 * card; + scale = hypers.scale + 0.5 * var_y; + + // Proto conversion + ProtoHypers out; + 
out.mutable_nnxig_state()->set_mean(mean); + out.mutable_nnxig_state()->set_var(var); + out.mutable_nnxig_state()->set_shape(shape); + out.mutable_nnxig_state()->set_scale(scale); + return std::make_shared(out); +} diff --git a/src/hierarchies/updaters/nnxig_updater.h b/src/hierarchies/updaters/nnxig_updater.h new file mode 100644 index 000000000..195b8c44f --- /dev/null +++ b/src/hierarchies/updaters/nnxig_updater.h @@ -0,0 +1,34 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_NNXIG_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_NNXIG_UPDATER_H_ + +#include "semi_conjugate_updater.h" +#include "src/hierarchies/likelihoods/uni_norm_likelihood.h" +#include "src/hierarchies/priors/nxig_prior_model.h" + +/** + * Updater specific for the `UniNormLikelihood` used in combination + * with `NxIGPriorModel`, that is the model + * + * \f[ + * y_i \mid \mu, \sigma^2 &\stackrel{\small\mathrm{iid}}{\sim} N(\mu, + * \sigma^2) \\ + * \mu &\sim N(\mu_0, \eta^2) \\ + * \sigma^2 & \sim InvGamma(a,b) + * \f] + * + * It exploits the semi-conjugacy of the model to sample the full conditional + * of \f$ (\mu, \sigma^2) \f$ by calling `NxIGPriorModel::sample` with updated + * parameters + */ + +class NNxIGUpdater + : public SemiConjugateUpdater { + public: + NNxIGUpdater() = default; + ~NNxIGUpdater() = default; + + ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; +}; + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_NNXIG_UPDATER_H_ diff --git a/src/hierarchies/updaters/random_walk_updater.h b/src/hierarchies/updaters/random_walk_updater.h new file mode 100644 index 000000000..cb3076206 --- /dev/null +++ b/src/hierarchies/updaters/random_walk_updater.h @@ -0,0 +1,78 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_RANDOM_WALK_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_RANDOM_WALK_UPDATER_H_ + +#include "metropolis_updater.h" + +/** + * Metropolis-Hastings updater using an isotropic proposal function + * centered in the current value of the 
parameters (unconstrained). + * This class requires that the Hierarchy's state implements + * the `get_unconstrained()`, `set_from_unconstrained()` and + * `log_det_jac()` functions. + * + * Given the current value of the unconstrained parameters \f$ x \f$, a new + * value is proposed from + * + * \f[ + * x_{new} \sim N(x, step\_size \cdot I) + * \f] + * + * and then either accepted (in which case the hierarchy's state is + * set to \f$ x_{new} \f$) or rejected. + */ + +class RandomWalkUpdater : public MetropolisUpdater { + protected: + double step_size; + + public: + RandomWalkUpdater() = default; + ~RandomWalkUpdater() = default; + + RandomWalkUpdater(double step_size) : step_size(step_size) {} + + //! Samples from the proposal distribution + //! @param curr_state the current state (unconstrained parametrization) + //! @param like instance of likelihood + //! @param prior instance of prior + //! @param target_lpdf either double or stan::math::var. Needed for + //! stan's automatic differentiation. It is not used here. + template + Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, + AbstractLikelihood &like, + AbstractPriorModel &prior, F &target_lpdf) { + Eigen::VectorXd step(curr_state.size()); + auto &rng = bayesmix::Rng::Instance().get(); + for (int i = 0; i < curr_state.size(); i++) { + step(i) = stan::math::normal_rng(0, step_size, rng); + } + return curr_state + step; + } + + //! Evaluates the log probability density function of the proposal + //! @param prop_state the proposed state (at which to evaluate the lpdf) + //! @param curr_state the current state (unconstrained parametrization) + //! @param like instance of likelihood + //! @param prior instance of prior + //! @param target_lpdf either double or stan::math::var. Needed for + //! stan's automatic differentiation. It is not used here. 
+ template + double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, + AbstractLikelihood &like, AbstractPriorModel &prior, + F &target_lpdf) { + double out = 0; + for (int i = 0; i < prop_state.size(); i++) { + out += stan::math::normal_lpdf(prop_state(i), curr_state(i), step_size); + } + return out; + } + + //! Returns a shared_ptr to a new instance of `this` + std::shared_ptr clone() const { + auto out = std::make_shared( + static_cast(*this)); + return out; + } +}; + +#endif diff --git a/src/hierarchies/updaters/semi_conjugate_updater.h b/src/hierarchies/updaters/semi_conjugate_updater.h new file mode 100644 index 000000000..5609bf1b8 --- /dev/null +++ b/src/hierarchies/updaters/semi_conjugate_updater.h @@ -0,0 +1,86 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_SEMI_CONJUGATE_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_SEMI_CONJUGATE_UPDATER_H_ + +#include + +#include "abstract_updater.h" +#include "src/hierarchies/likelihoods/abstract_likelihood.h" +#include "src/hierarchies/priors/abstract_prior_model.h" + +//! Updater for semi-conjugate hierarchies. +//! +//! We say that a hierarchy is semi-conjugate if the full conditionals +//! of each parameter is in the same parametric family of the prior +//! distribution of that parameter. +//! +//! As a consequence, sampling from the full conditional can be done +//! by calling the `sample` method from the `PriorModel` class, with +//! updater hyperparameters +//! +//! Classes inheriting from this one should only implement the +//! `compute_posterior_hypers(...)` member function +//! +//! This class is templated with respect to +//! @tparam Likelihood: the likelihood of the hierarchy, instance of +//! `AbstractLikelihood` +//! @tparam PriorModel: the prior of the hierarchy, instance of +//! 
`AbstractPriorModel` + +template +class SemiConjugateUpdater : public AbstractUpdater { + public: + SemiConjugateUpdater() = default; + + ~SemiConjugateUpdater() = default; + + void draw(AbstractLikelihood& like, AbstractPriorModel& prior, + bool update_params) override; + + //! Used by algorithms such as Neal3 and SplitMerge + //! It stores the hyperparameters computed by `compute_posterior_hypers` + void save_posterior_hypers(ProtoHypersPtr post_hypers_) override; + + protected: + Likelihood& downcast_likelihood(AbstractLikelihood& like_); + PriorModel& downcast_prior(AbstractPriorModel& prior_); + ProtoHypersPtr post_hypers = std::make_shared(); +}; + +// Methods' definitions +template +Likelihood& SemiConjugateUpdater::downcast_likelihood( + AbstractLikelihood& like_) { + return static_cast(like_); +} + +template +PriorModel& SemiConjugateUpdater::downcast_prior( + AbstractPriorModel& prior_) { + return static_cast(prior_); +} + +template +void SemiConjugateUpdater::draw( + AbstractLikelihood& like, AbstractPriorModel& prior, bool update_params) { + // Likelihood and PriorModel downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + // Sample from the full conditional of a semi-conjugate hierarchy + bool set_card = true; + if (likecast.get_card() == 0) { + likecast.set_state(priorcast.sample(), !set_card); + } else { + auto post_params = compute_posterior_hypers(likecast, priorcast); + likecast.set_state(priorcast.sample(post_params), !set_card); + if (update_params) save_posterior_hypers(post_params); + } +} + +template +void SemiConjugateUpdater::save_posterior_hypers( + ProtoHypersPtr post_hypers_) { + post_hypers = post_hypers_; + return; +} + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_SEMI_CONJUGATE_UPDATER_H_ diff --git a/src/hierarchies/updaters/target_lpdf_unconstrained.h b/src/hierarchies/updaters/target_lpdf_unconstrained.h new file mode 100644 index 000000000..803900904 --- /dev/null +++ 
b/src/hierarchies/updaters/target_lpdf_unconstrained.h @@ -0,0 +1,31 @@ +#ifndef BAYESMIX_SRC_HIERARCHIES_UPDATERS_TARGET_LPDF_UNCONSTRAINED_H_ +#define BAYESMIX_SRC_HIERARCHIES_UPDATERS_TARGET_LPDF_UNCONSTRAINED_H_ + +#include "src/hierarchies/likelihoods/abstract_likelihood.h" +#include "src/hierarchies/priors/abstract_prior_model.h" + +//! Functor that computes the log-full conditional distribution +//! of a specific hierarchy. +//! Used by metropolis-like updaters especially when the gradient +//! of the target_lpdf if required +class target_lpdf_unconstrained { + protected: + AbstractLikelihood* like; + AbstractPriorModel* prior; + + public: + target_lpdf_unconstrained(AbstractLikelihood* like, + AbstractPriorModel* prior) + : like(like), prior(prior) {} + + //! Computes the log-full conditional that is simply the + //! sum of `cluster_lpdf_from_unconstrained` in `AbstractLikelihood` + //! and `lpdf_from_unconstrained` in `AbstractPriorModel` + template + T operator()(const Eigen::Matrix& x) const { + return like->cluster_lpdf_from_unconstrained(x) + + prior->lpdf_from_unconstrained(x); + } +}; + +#endif diff --git a/src/includes.h b/src/includes.h index ec91db35e..d4d19b8ca 100644 --- a/src/includes.h +++ b/src/includes.h @@ -9,12 +9,13 @@ #include "algorithms/neal8_algorithm.h" #include "collectors/file_collector.h" #include "collectors/memory_collector.h" +#include "hierarchies/fa_hierarchy.h" #include "hierarchies/lapnig_hierarchy.h" #include "hierarchies/lin_reg_uni_hierarchy.h" #include "hierarchies/load_hierarchies.h" -#include "hierarchies/fa_hierarchy.h" #include "hierarchies/nnig_hierarchy.h" #include "hierarchies/nnw_hierarchy.h" +#include "hierarchies/nnxig_hierarchy.h" #include "mixings/dirichlet_mixing.h" #include "mixings/load_mixings.h" #include "mixings/logit_sb_mixing.h" diff --git a/src/mixings/abstract_mixing.h b/src/mixings/abstract_mixing.h index 52863dbdb..bfb2a0e33 100644 --- a/src/mixings/abstract_mixing.h +++ 
b/src/mixings/abstract_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "mixing_id.pb.h" diff --git a/src/mixings/base_mixing.h b/src/mixings/base_mixing.h index c9b884c06..7fedf25af 100644 --- a/src/mixings/base_mixing.h +++ b/src/mixings/base_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include "abstract_mixing.h" diff --git a/src/mixings/dirichlet_mixing.h b/src/mixings/dirichlet_mixing.h index 4812df4f9..f1032624c 100644 --- a/src/mixings/dirichlet_mixing.h +++ b/src/mixings/dirichlet_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "base_mixing.h" @@ -12,24 +12,30 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" -//! Class that represents the EPPF induced by the Dirithclet process (DP) -//! introduced in Ferguson (1973), see also Sethuraman (1994). -//! The EPPF induced by the DP depends on a `totalmass` parameter M. -//! Given a clustering of n elements into k clusters, each with cardinality -//! n_j, j=1, ..., k, the EPPF of the DP gives the following probabilities for -//! the cluster membership of the (n+1)-th observation: -//! p(j-th cluster | ...) = n_j / (n + M) -//! p(k+1-th cluster | ...) = M / (n + M) -//! The state is solely composed of M, but we also store log(M) for efficiency -//! reasons. For more information about the class, please refer instead to base -//! classes, `AbstractMixing` and `BaseMixing`. - namespace Dirichlet { struct State { double totalmass, logtotmass; }; }; // namespace Dirichlet +/** + * Class that represents the EPPF induced by the Dirithclet process (DP) + * introduced in Ferguson (1973), see also Sethuraman (1994). + * The EPPF induced by the DP depends on a `totalmass` parameter M. 
+ * Given a clustering of n elements into k clusters, each with cardinality + * \f$ n_j, j=1, ..., k \f$ the EPPF of the DP gives the following + * probabilities for the cluster membership of the (n+1)-th observation: + * + * \f[ + * p(\text{j-th cluster} | ...) &= n_j / (n + M) \\ + * p(\text{new cluster} | ...) &= M / (n + M) + * \f] + * + * The state is solely composed of M, but we also store log(M) for efficiency + * reasons. For more information about the class, please refer instead to base + * classes, `AbstractMixing` and `BaseMixing`. + */ + class DirichletMixing : public BaseMixing { public: @@ -60,6 +66,7 @@ class DirichletMixing protected: //! Returns probability mass for an old cluster (for marginal mixings only) //! @param n Total dataset size + //! @param n_clust Number of clusters //! @param log Whether to return logarithm-scale values or not //! @param propto Whether to include normalizing constants or not //! @param hier `Hierarchy` object representing the cluster diff --git a/src/mixings/logit_sb_mixing.cc b/src/mixings/logit_sb_mixing.cc index ed7514f58..a3711086d 100644 --- a/src/mixings/logit_sb_mixing.cc +++ b/src/mixings/logit_sb_mixing.cc @@ -2,10 +2,10 @@ #include -#include #include #include #include +#include #include #include "mixing_prior.pb.h" diff --git a/src/mixings/logit_sb_mixing.h b/src/mixings/logit_sb_mixing.h index 40727cf57..1061029c7 100644 --- a/src/mixings/logit_sb_mixing.h +++ b/src/mixings/logit_sb_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "base_mixing.h" @@ -12,30 +12,35 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" -//! Class that represents the logit stick-breaking process indroduced in Rigon -//! and Durante (2020). -//! That is, a prior for weights (w_1,...,w_H), depending on covariates x in -//! R^p, in the H-1 dimensional unit simplex, defined as follows: -//! w_1(x) = v_1(x) -//! w_j(x) = v_j(x) (1 - v_1(x)) ... (1 - v_{j-1}(x)), for j=2, ... 
H-1 -//! w_H(x) = 1 - (w_1(x) + w_2 + ... + w_{H-1}(x)) -//! and -//! v_j(x) = 1 / exp(- ), for j = 1, ..., H-1 -//! -//! The main difference with the mentioned paper is that the authors propose a -//! Gibbs sampler in which the full conditionals are available in close form -//! thanks to a Polya-Gamma augmentation. Here instead, a Metropolis-adjusted -//! Langevin algorithm (MALA) step is used. The step-size of the MALA step must -//! be passed in the LogSBPrior Protobuf message. -//! For more information about the class, please refer instead to base classes, -//! `AbstractMixing` and `BaseMixing`. - namespace LogitSB { struct State { Eigen::MatrixXd regression_coeffs, precision; }; }; // namespace LogitSB +/** + * Class that represents the logit stick-breaking process indroduced in Rigon + * and Durante (2020). + * That is, a prior for weights \f$ (w_1,\dots,w_H) \f$, depending on + * covariates \f$ x \f$ in \f$ \mathbb{R}^p \f$, in the H-1 dimensional unit + * simplex, defined as follows: + * + * \f[ + * w_1 &= v_1 \\ + * w_j &= v_j (1 - v_1) ... (1 - v_{j-1}), \quad \text{for } j=1, ... H-1 \\ + * w_H &= 1 - (w_1 + w_2 + ... + w_{H-1}) \\ + * v_j(x) &= 1 / exp(- <\alpha_j, x> ), for j = 1, ..., H-1 + * \f] + * + * The main difference with the mentioned paper is that the authors propose a + * Gibbs sampler in which the full conditionals are available in close form + * thanks to a Polya-Gamma augmentation. Here instead, a Metropolis-adjusted + * Langevin algorithm (MALA) step is used. The step-size of the MALA step must + * be passed in the LogSBPrior Protobuf message. + * For more information about the class, please refer instead to base classes, + * `AbstractMixing` and `BaseMixing`. 
+ */ + class LogitSBMixing : public BaseMixing { public: diff --git a/src/mixings/mixture_finite_mixing.h b/src/mixings/mixture_finite_mixing.h index 9813f986a..67ce224dc 100644 --- a/src/mixings/mixture_finite_mixing.h +++ b/src/mixings/mixture_finite_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "base_mixing.h" @@ -12,31 +12,42 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" -//! Class that represents the Mixture of Finite Mixtures (MFM) [1] -//! The basic idea is to take usual finite mixture model with Dirichlet weights -//! and put a prior (Poisson) on the number of components. The EPPF induced by -//! MFM depends on a Dirichlet parameter 'gamma' and number V_n(t), where -//! V_n(t) depends on the Poisson rate parameter 'lambda'. -//! V_n(t) = sum_{k=1}^{inf} ( k_(t)*p_K(k) / (gamma*k)^(n) ) -//! Given a clustering of n elements into k clusters, each with cardinality -//! n_j, j=1, ..., k, the EPPF of the MFM gives the following probabilities for -//! the cluster membership of the (n+1)-th observation: -//! denominator = n_j + gamma / (n + gamma*(n_clust + V[n_clust+1]/V[n_clust])) -//! p(j-th cluster | ...) = (n_j + gamma) / denominator -//! p(k+1-th cluster | ...) = V[n_clust+1]/V[n_clust]*gamma / denominator -//! For numerical reasons each value of V is multiplied with a constant C -//! computed as the first term of the series of V_n[0]. -//! For more information about the class, please refer instead to base -//! classes, `AbstractMixing` and `BaseMixing`. -//! [1] "Mixture Models with a Prior on the Number of Components", J.W.Miller -//! and M.T.Harrison, 2015, arXiv:1502.06241v1 - namespace Mixture_Finite { struct State { double lambda, gamma; }; }; // namespace Mixture_Finite +/** + * Class that represents the Mixture of Finite Mixtures (MFM) [1] + * The basic idea is to take usual finite mixture model with Dirichlet weights + * and put a prior (Poisson) on the number of components. 
The EPPF induced by + * MFM depends on a Dirichlet parameter 'gamma' and number \f$V_n(t)\f$, where + * \f$ V_n(t) \f$ depends on the Poisson rate parameter 'lambda'. + * + * \f[ + * V_n(t) = \sum_{k=1}^{\infty} ( k_(t)p_K(k) / (\gamma*k)^(n) ) + * \f] + * + * Given a clustering of n elements into k clusters, each with cardinality + * \f$ n_j, j=1, ..., k \f$, the EPPF of the MFM gives the following + * probabilities for the cluster membership of the (n+1)-th observation: + * + * \f[ + * p(\text{j-th cluster} | ...) &= (n_j + \gamma) / D \\ + * p(\text{k+1-th cluster} | ...) &= V[k+1]/V[k] \gamma / D \\ + * D &= n_j + \gamma / (n + \gamma * (k + V[k+1]/V[k])) + * \f] + * + * For numerical reasons each value of V is multiplied with a constant C + * computed as the first term of the series of V_n[0]. + * For more information about the class, please refer instead to base + * classes, `AbstractMixing` and `BaseMixing`. + * + * [1] "Mixture Models with a Prior on the Number of Components", J.W.Miller + * and M.T.Harrison, 2015, arXiv:1502.06241v1 + */ + class MixtureFiniteMixing : public BaseMixing { @@ -70,6 +81,7 @@ class MixtureFiniteMixing protected: //! Returns probability mass for an old cluster (for marginal mixings only) //! @param n Total dataset size + //! @param n_clust Number of clusters //! @param log Whether to return logarithm-scale values or not //! @param propto Whether to include normalizing constants or not //! 
@param hier `Hierarchy` object representing the cluster diff --git a/src/mixings/pityor_mixing.cc b/src/mixings/pityor_mixing.cc index ffdc98c6e..2dc3bdeaa 100644 --- a/src/mixings/pityor_mixing.cc +++ b/src/mixings/pityor_mixing.cc @@ -2,9 +2,9 @@ #include -#include #include #include +#include #include #include "mixing_prior.pb.h" diff --git a/src/mixings/pityor_mixing.h b/src/mixings/pityor_mixing.h index ba4598f52..54c68c05d 100644 --- a/src/mixings/pityor_mixing.h +++ b/src/mixings/pityor_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "base_mixing.h" @@ -12,26 +12,31 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" -//! Class that represents the Pitman-Yor process (PY) in Pitman and Yor (1997). -//! The EPPF induced by the PY depends on a `strength` parameter M and a -//! `discount` paramter d. -//! Given a clustering of n elements into k clusters, each with cardinality -//! n_j, j=1, ..., k, the EPPF of the PY gives the following probabilities for -//! the cluster membership of the (n+1)-th observation: -//! p(j-th cluster | ...) \propto (n_j - d) -//! p(k+1-th cluster | ...) \propto M + k * d -//! -//! When `discount=0`, the EPPF of the PY process coincides with the one of the -//! DP with totalmass = strength. -//! For more information about the class, please refer instead to base classes, -//! `AbstractMixing` and `BaseMixing`. - namespace PitYor { struct State { double strength, discount; }; }; // namespace PitYor +/** + * Class that represents the Pitman-Yor process (PY) in Pitman and Yor (1997). + * The EPPF induced by the PY depends on a `strength` parameter M and a + * `discount` parameter d. + * Given a clustering of n elements into k clusters, each with cardinality + * \f$ n_j, j=1, ..., k \f$, the EPPF of the PY gives the following + * probabilities for the cluster membership of the (n+1)-th observation: + * + * \f[ + * p(\text{j-th cluster} | ...) 
\propto (n_j - d) \\ + * p(\text{new cluster} | ...) \propto M + k d + * \f] + * + * When `discount=0`, the EPPF of the PY process coincides with the one of the + * DP with totalmass = strength. + * For more information about the class, please refer instead to base classes, + * `AbstractMixing` and `BaseMixing`. + */ + class PitYorMixing : public BaseMixing { public: @@ -62,6 +67,7 @@ class PitYorMixing protected: //! Returns probability mass for an old cluster (for marginal mixings only) //! @param n Total dataset size + //! @param n_clust Number of clusters //! @param log Whether to return logarithm-scale values or not //! @param propto Whether to include normalizing constants or not //! @param hier `Hierarchy` object representing the cluster diff --git a/src/mixings/truncated_sb_mixing.cc b/src/mixings/truncated_sb_mixing.cc index efe73e3d5..b826d5582 100644 --- a/src/mixings/truncated_sb_mixing.cc +++ b/src/mixings/truncated_sb_mixing.cc @@ -2,11 +2,11 @@ #include -#include #include #include #include #include +#include #include #include "logit_sb_mixing.h" diff --git a/src/mixings/truncated_sb_mixing.h b/src/mixings/truncated_sb_mixing.h index 27e16c7fd..4d7d190fb 100644 --- a/src/mixings/truncated_sb_mixing.h +++ b/src/mixings/truncated_sb_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "base_mixing.h" @@ -12,30 +12,37 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" -//! Class that represents a truncated stick-breaking process, as shown in -//! Ishwaran and James (2001). -//! -//! A truncated stick-breaking process is a prior for weights (w_1,...,w_H) in -//! the H-1 dimensional unit simplex, and is defined as follows: -//! w_1 = v_1 -//! w_j = v_j (1 - v_1) ... (1 - v_{j-1}), for j=1, ... H-1 -//! w_H = 1 - (w_1 + w_2 + ... + w_{H-1}) -//! The v_j's are called sticks and we assume them to be independently -//! distributed as v_j ~ Beta(a_j, b_j). -//! -//! 
When a_j = 1 and b_j = M, the stick-breaking process is a truncation of the -//! stick-breaking representation of the DP. -//! When a_j = 1 - d and b_j = M + i*d, it is the trunctation of a PY process. -//! Its state is composed of the weights w_j in log-scale and the sticks v_j. -//! For more information about the class, please refer instead to base classes, -//! `AbstractMixing` and `BaseMixing`. - namespace TruncSB { struct State { Eigen::VectorXd sticks, logweights; }; }; // namespace TruncSB +/** + * Class that represents a truncated stick-breaking process, as shown in + * Ishwaran and James (2001). + * + * A truncated stick-breaking process is a prior for weights + * \f$ (w_1,...,w_H) \f$ in the H-1 dimensional unit simplex, and is defined as + * follows: + * + * \f[ + * w_1 &= v_1 \\ + * w_j &= v_j (1 - v_1) ... (1 - v_{j-1}), \quad \text{for } j=1, ... H-1 \\ + * w_H &= 1 - (w_1 + w_2 + ... + w_{H-1}) + * \f] + * + * The \f$ v_j \f$'s are called sticks and we assume them to be independently + * distributed as \f$ v_j \sim \text{Beta}(a_j, b_j) \f$. + * + * When \f$ a_j = 1 \f$ and \f$ b_j = M \f$, the stick-breaking process is a + * truncation of the stick-breaking representation of the DP. + * When \f$ a_j = 1-d \f$ and \f$ b_j = M+id \f$, it is the truncation of a PY + * process. Its state is composed of the weights \f$ w_j \f$ in log-scale and + * the sticks \f$ v_j \f$. For more information about the class, please refer + * instead to base classes, `AbstractMixing` and `BaseMixing`. 
+ */ + class TruncatedSBMixing : public BaseMixing { public: diff --git a/src/proto/algorithm_state.proto b/src/proto/algorithm_state.proto index 39df18641..d106d8cf9 100644 --- a/src/proto/algorithm_state.proto +++ b/src/proto/algorithm_state.proto @@ -38,11 +38,11 @@ message AlgorithmState { message HierarchyHypers { // Current values of the Hyperparameters of the Hierarchy oneof val { - EmptyPrior fake_prior = 1; + Vector general_state = 1; NIGDistribution nnig_state = 2; NWDistribution nnw_state = 3; MultiNormalIGDistribution lin_reg_uni_state = 4; - LapNIGState lapnig_state = 6; + NxIGDistribution nnxig_state = 5; FAPriorDistribution fa_state = 7; } } diff --git a/src/proto/distribution.proto b/src/proto/distribution.proto index e27e89a29..c123765a9 100644 --- a/src/proto/distribution.proto +++ b/src/proto/distribution.proto @@ -57,6 +57,18 @@ message NIGDistribution { double scale = 4; } +/* + * Parameters of a Normal x Inverse-Gamma distribution + * with density + * f(x, y) = N(x | mu, var) * IG(y | shape, scale) + */ +message NxIGDistribution { + double mean = 1; + double var = 2; + double shape = 3; + double scale = 4; +} + /* * Parameters of a Normal Wishart distribution * with density diff --git a/src/proto/hierarchy_id.proto b/src/proto/hierarchy_id.proto index e8047d1dc..a6817aa5b 100644 --- a/src/proto/hierarchy_id.proto +++ b/src/proto/hierarchy_id.proto @@ -12,4 +12,5 @@ enum HierarchyId { LinRegUni = 3; // Linear Regression (univariate response) LapNIG = 4; // Laplace - Normal Inverse Gamma FA = 5; // Factor Analysers + NNxIG = 6; // Normal - Normal x Inverse Gamma } diff --git a/src/proto/hierarchy_prior.proto b/src/proto/hierarchy_prior.proto index 2cd76fda4..866189f6a 100644 --- a/src/proto/hierarchy_prior.proto +++ b/src/proto/hierarchy_prior.proto @@ -31,6 +31,16 @@ message NNIGPrior { } } +/* + * Prior for the parameters of the base measure in a Normal-Normal x Inverse Gamma hierarchy + */ +message NNxIGPrior { + + oneof prior { + 
NxIGDistribution fixed_values = 1; // no prior, just fixed values + } +} + /* * Prior for the parameters of the base measure in a Laplace - Normal Inverse Gamma hierarchy */ @@ -75,7 +85,6 @@ message NNWPrior { } } - /* * Prior for the parameters of the base measure in a Normal mixture model with a covariate-dependent * location. diff --git a/src/runtime/factory.h b/src/runtime/factory.h index f496bd42f..8a4bbba31 100644 --- a/src/runtime/factory.h +++ b/src/runtime/factory.h @@ -77,7 +77,7 @@ class Factory { //! Adds a builder function to the storage //! @param id Identifier to associate the builder with - //! @param bulider Builder function for a specific object type + //! @param builder Builder function for a specific object type void add_builder(const Identifier &id, const Builder &builder) { storage.insert(std::make_pair(id, builder)); } diff --git a/src/utils/CMakeLists.txt b/src/utils/CMakeLists.txt index dfec7c2a9..c891c24cf 100644 --- a/src/utils/CMakeLists.txt +++ b/src/utils/CMakeLists.txt @@ -15,4 +15,5 @@ target_sources(bayesmix rng.h testing_utils.h testing_utils.cc + covariates_getter.h ) diff --git a/src/utils/cluster_utils.cc b/src/utils/cluster_utils.cc index d1c63503c..89c65c39b 100644 --- a/src/utils/cluster_utils.cc +++ b/src/utils/cluster_utils.cc @@ -1,7 +1,7 @@ #include "cluster_utils.h" -#include #include +#include #include "lib/progressbar/progressbar.h" #include "proto_utils.h" @@ -35,6 +35,7 @@ Eigen::VectorXi bayesmix::cluster_estimate( std::cout << "Done)" << std::endl; // Compute Frobenius norm error of all iterations + std::cout << "Computing Frobenius norm error... 
" << std::endl; Eigen::VectorXd errors(n_iter); for (int k = 0; k < n_iter; k++) { for (int i = 0; i < n_data; i++) { @@ -48,6 +49,7 @@ Eigen::VectorXi bayesmix::cluster_estimate( bar.display(); } bar.done(); + std::cout << "Done" << std::endl; // Print Ending Message // Find iteration with the least error std::ptrdiff_t ibest; diff --git a/src/utils/cluster_utils.h b/src/utils/cluster_utils.h index c4466db8c..1930bab16 100644 --- a/src/utils/cluster_utils.h +++ b/src/utils/cluster_utils.h @@ -1,17 +1,20 @@ #ifndef BAYESMIX_UTILS_CLUSTER_UTILS_H_ #define BAYESMIX_UTILS_CLUSTER_UTILS_H_ -#include +#include -//! This file includes some utilities for cluster estimation. These functions -//! only use Eigen ojects. +//! \file cluster_utils.h +//! The `cluster_utils.h` file includes some utilities for cluster estimation. +//! These functions only use Eigen objects. namespace bayesmix { + //! Computes the posterior similarity matrix the data Eigen::MatrixXd posterior_similarity(const Eigen::MatrixXd &alloc_chain); //! 
Estimates the clustering structure of the data via LS minimization Eigen::VectorXi cluster_estimate(const Eigen::MatrixXi &alloc_chain); + } // namespace bayesmix #endif // BAYESMIX_UTILS_CLUSTER_UTILS_H_ diff --git a/src/utils/covariates_getter.h b/src/utils/covariates_getter.h new file mode 100644 index 000000000..530590182 --- /dev/null +++ b/src/utils/covariates_getter.h @@ -0,0 +1,25 @@ +#ifndef BAYESMIX_SRC_UTILS_COVARIATES_GETTER_H +#define BAYESMIX_SRC_UTILS_COVARIATES_GETTER_H + +#include + +class covariates_getter { + protected: + const Eigen::MatrixXd* covariates; + + public: + covariates_getter(const Eigen::MatrixXd& covariates_) + : covariates(&covariates_){}; + + Eigen::RowVectorXd operator()(const size_t& i) const { + if (covariates->cols() == 0) { + return Eigen::RowVectorXd(0); + } else if (covariates->rows() == 1) { + return covariates->row(0); + } else { + return covariates->row(i); + } + }; +}; + +#endif // BAYESMIX_SRC_UTILS_COVARIATES_GETTER_H diff --git a/src/utils/distributions.cc b/src/utils/distributions.cc index 5b6e0762d..c4536c1a8 100644 --- a/src/utils/distributions.cc +++ b/src/utils/distributions.cc @@ -1,10 +1,10 @@ #include "distributions.h" -#include #include #include #include #include +#include #include "src/utils/proto_utils.h" @@ -18,22 +18,24 @@ double bayesmix::multi_normal_prec_lpdf(const Eigen::VectorXd &datum, const Eigen::MatrixXd &prec_chol, const double prec_logdet) { using stan::math::NEG_LOG_SQRT_TWO_PI; - double base = prec_logdet + NEG_LOG_SQRT_TWO_PI * datum.size(); - double exp = (prec_chol * (datum - mean)).squaredNorm(); - return 0.5 * (base - exp); + double base = 0.5 * prec_logdet + NEG_LOG_SQRT_TWO_PI * datum.size(); + double exp = 0.5 * (prec_chol.transpose() * (datum - mean)).squaredNorm(); + return base - exp; } Eigen::VectorXd bayesmix::multi_normal_prec_lpdf_grid( const Eigen::MatrixXd &data, const Eigen::VectorXd &mean, const Eigen::MatrixXd &prec_chol, const double prec_logdet) { using 
stan::math::NEG_LOG_SQRT_TWO_PI; - Eigen::VectorXd exp = - ((data.rowwise() - mean.transpose()) * prec_chol.transpose()) - .rowwise() - .squaredNorm(); - Eigen::VectorXd base = Eigen::ArrayXd::Ones(data.rows()) * prec_logdet + - NEG_LOG_SQRT_TWO_PI * data.cols(); - return (base - exp) * 0.5; + + Eigen::VectorXd exp = ((data.rowwise() - mean.transpose()) * prec_chol) + .rowwise() + .squaredNorm() * + 0.5; + Eigen::VectorXd base = + Eigen::ArrayXd::Ones(data.rows()) * prec_logdet * 0.5 + + NEG_LOG_SQRT_TWO_PI * data.cols(); + return base - exp; } Eigen::VectorXd bayesmix::multi_normal_diag_rng( diff --git a/src/utils/distributions.h b/src/utils/distributions.h index 20f5b587d..d6b7a0635 100644 --- a/src/utils/distributions.h +++ b/src/utils/distributions.h @@ -1,20 +1,21 @@ #ifndef BAYESMIX_UTILS_DISTRIBUTIONS_H_ #define BAYESMIX_UTILS_DISTRIBUTIONS_H_ -#include #include +#include #include #include "algorithm_state.pb.h" -//! This file includes several useful functions related to probability -//! distributions, including categorical variables, popular multivariate -//! distributions, and distribution distances. Some of these functions make use -//! of OpenMP parallelism to achieve better efficiency. +//! @file distributions.h +//! The `distributions.h` file includes several useful functions related to +//! probability distributions, including categorical variables, popular +//! multivariate distributions, and distribution distances. Some of these +//! functions make use of OpenMP parallelism to achieve better efficiency. 
namespace bayesmix { -/* +/** * Returns a pseudorandom categorical random variable on the set * {start, ..., start + k} where k is the size of the given probability vector * @@ -26,64 +27,66 @@ namespace bayesmix { int categorical_rng(const Eigen::VectorXd &probas, std::mt19937_64 &rng, const int start = 0); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution parametrized by mean and precision matrix on a single point * - * @param datum Point in which to evaluate the the lpdf - * @param mean The mean of the Gaussian distribution - * @prec_chol The (lower) Cholesky factor of the precision matrix - * @prec_logdet The logarithm of the determinant of the precision matrix - * @return The evaluation of the lpdf + * @param datum Point in which to evaluate the the lpdf + * @param mean The mean of the Gaussian distribution + * @param prec_chol The (lower) Cholesky factor of the precision matrix + * @param prec_logdet The logarithm of the determinant of the precision + * matrix + * @return The evaluation of the lpdf */ double multi_normal_prec_lpdf(const Eigen::VectorXd &datum, const Eigen::VectorXd &mean, const Eigen::MatrixXd &prec_chol, const double prec_logdet); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution parametrized by mean and precision matrix on multiple points * - * @param data Grid of points (by row) on which to evaluate the lpdf - * @param mean The mean of the Gaussian distribution - * @prec_chol The (lower) Cholesky factor of the precision matrix - * @prec_logdet The logarithm of the determinant of the precision matrix - * @return The evaluation of the lpdf + * @param data Grid of points (by row) on which to evaluate the lpdf + * @param mean The mean of the Gaussian distribution + * @param prec_chol The (lower) Cholesky factor of the precision matrix + * @param prec_logdet The logarithm of the determinant of the precision + * matrix + * @return The evaluation 
of the lpdf */ Eigen::VectorXd multi_normal_prec_lpdf_grid(const Eigen::MatrixXd &data, const Eigen::VectorXd &mean, const Eigen::MatrixXd &prec_chol, const double prec_logdet); -/* +/** * Returns a pseudorandom multivariate normal random variable with diagonal * covariance matrix * - * @param mean The mean of the Gaussian r.v. - * @param cov_diag The diagonal covariance matrix - * @rng Random number generator - * @return multivariate normal r.v. + * @param mean The mean of the Gaussian r.v. + * @param cov_diag The diagonal covariance matrix + * @param rng Random number generator + * @return Multivariate normal r.v. */ Eigen::VectorXd multi_normal_diag_rng( const Eigen::VectorXd &mean, const Eigen::DiagonalMatrix &cov_diag, std::mt19937_64 &rng); -/* +/** * Returns a pseudorandom multivariate normal random variable parametrized * through mean and Cholesky decomposition of precision matrix * - * @param mean The mean of the Gaussian r.v. - * @prec_chol The (lower) Cholesky factor of the precision matrix - * @param rng Random number generator - * @return multivariate normal r.v. + * @param mean The mean of the Gaussian r.v. + * @param prec_chol The (lower) Cholesky factor of the precision matrix + * @param rng Random number generator + * @return Multivariate normal r.v. 
*/ Eigen::VectorXd multi_normal_prec_chol_rng( const Eigen::VectorXd &mean, const Eigen::LLT &prec_chol, std::mt19937_64 &rng); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution with the following covariance structure: * Sigma + Lambda * Lambda^T @@ -91,7 +94,6 @@ Eigen::VectorXd multi_normal_prec_chol_rng( * y^T*(Sigma + Lambda * Lambda^T)^{-1}*y = y^T*Sigma^{-1}*y - * ||wood_factor*y||^2 * - * * @param datum Point on which to evaluate the lpdf * @param mean The mean of the Gaussian distribution * @param sigma_diag_inverse The inverse of the diagonal of Sigma matrix @@ -106,7 +108,7 @@ double multi_normal_lpdf_woodbury_chol( const Eigen::DiagonalMatrix &sigma_diag_inverse, const Eigen::MatrixXd &wood_factor, const double &cov_logdet); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution with the following covariance structure: * Sigma + Lambda * Lambda^T @@ -116,41 +118,42 @@ double multi_normal_lpdf_woodbury_chol( * computation from being O(p^3) to being O(d^3 p) which gives a substantial * speedup when p >> d * - * @param datum Point on which to evaluate the lpdf - * @param mean The mean of the Gaussian distribution + * @param datum Point on which to evaluate the lpdf + * @param mean The mean of the Gaussian distribution * @param sigma_diag The diagonal of Sigma matrix * @param lambda Rectangular matrix in Woodbury Identity - * @return The evaluation of the lpdf + * @return The evaluation of the lpdf */ double multi_normal_lpdf_woodbury(const Eigen::VectorXd &datum, const Eigen::VectorXd &mean, const Eigen::VectorXd &sigma_diag, const Eigen::MatrixXd &lambda); -/* - * Returns the log-determinant of the matrix Lambda Lambda^T + Sigma +/** + * Returns the log-determinant of the matrix \f$ \Lambda\Lambda^T + \Sigma \f$ * and the 'wood_factor', i.e. 
* L^{-1} * Lambda^T * Sigma^{-1}, * where L is the (lower) Cholesky factor of * I + Lambda^T * Sigma^{-1} * Lambda * - * @param sigma_dag_inverse The inverse of the diagonal matrix Sigma + * @param sigma_diag_inverse The inverse of the diagonal matrix Sigma * @param lambda The matrix Lambda */ std::pair compute_wood_chol_and_logdet( const Eigen::DiagonalMatrix &sigma_diag_inverse, const Eigen::MatrixXd &lambda); -/* +/** * Evaluates the log probability density function of a multivariate Student's t * distribution on a single point * - * @param datum Point in which to evaluate the the lpdf - * @param df The degrees of freedom of the Student's t distribution - * @param mean The mean of the Student's t distribution - * @invscale_chol The (lower) Cholesky factor of the inverse scale matrix - * @prec_logdet The logarithm of the determinant of the inverse scale matrix - * @return The evaluation of the lpdf + * @param datum Point in which to evaluate the lpdf + * @param df The degrees of freedom of the Student's t distribution + * @param mean The mean of the Student's t distribution + * @param invscale_chol The (lower) Cholesky factor of the inverse scale matrix + * @param scale_logdet The logarithm of the determinant of the inverse scale + * matrix + * @return The evaluation of the lpdf */ double multi_student_t_invscale_lpdf(const Eigen::VectorXd &datum, const double df,
inverse scale matrix - * @return The evaluation of the lpdf + * @param data Grid of points (by row) on which to evaluate the lpdf + * @param df The degrees of freedom of the Student's t distribution + * @param mean The mean of the Student's t distribution + * @param invscale_chol The (lower) Cholesky factor of the inverse scale matrix + * @param scale_logdet The logarithm of the determinant of the inverse scale + * matrix + * @return The evaluation of the lpdf */ Eigen::VectorXd multi_student_t_invscale_lpdf_grid( const Eigen::MatrixXd &data, const double df, const Eigen::VectorXd &mean, const Eigen::MatrixXd &invscale_chol, const double scale_logdet); -/* +/** * Computes the L^2 distance between the univariate mixture of Gaussian - * densities p1(x) = \sum_{h=1}^m1 w1[h] N(x | mean1[h], var1[h]) and - * p2(x) = \sum_{h=1}^m2 w2[h] N(x | mean2[h], var2[h]) + * densities p1(x) = sum_{h=1}^m1 w1[h] N(x | mean1[h], var1[h]) and + * p2(x) = sum_{h=1}^m2 w2[h] N(x | mean2[h], var2[h]) * * The L^2 distance amounts to - * d(p, q) = (\int (p(x) - q(x)^2 dx))^{1/2} + * d(p, q) = (int (p(x) - q(x)^2 dx))^{1/2} */ double gaussian_mixture_dist(const Eigen::VectorXd &means1, const Eigen::VectorXd &vars1, @@ -188,13 +192,13 @@ double gaussian_mixture_dist(const Eigen::VectorXd &means1, const Eigen::VectorXd &vars2, const Eigen::VectorXd &weights2); -/* +/** * Computes the L^2 distance between the multivariate mixture of Gaussian - * densities p1(x) = \sum_{h=1}^m1 w1[h] N(x | mean1[h], Prec[1]^{-1}) and - * p2(x) = \sum_{h=1}^m2 w2[h] N(x | mean2[h], Prec2[h]^{-1}) + * densities p1(x) = sum_{h=1}^m1 w1[h] N(x | mean1[h], Prec[1]^{-1}) and + * p2(x) = sum_{h=1}^m2 w2[h] N(x | mean2[h], Prec2[h]^{-1}) * * The L^2 distance amounts to - * d(p, q) = (\int (p(x) - q(x)^2 dx))^{1/2} + * d(p, q) = (int (p(x) - q(x)^2 dx))^{1/2} */ double gaussian_mixture_dist(const std::vector &means1, const std::vector &precs1, @@ -203,11 +207,11 @@ double gaussian_mixture_dist(const std::vector 
&means1, const std::vector &precs2, const Eigen::VectorXd &weights2); -/* +/** * Computes the L^2 distance between the mixture of Gaussian * densities p(x) and q(x). These could be either univariate or multivariate. * The L2 distance amounts to - * d(p, q) = (\int (p(x) - q(x)^2 dx))^{1/2} + * d(p, q) = (int (p(x) - q(x)^2 dx))^{1/2} * * @param clus1, clus2 Cluster-specific parameters of the mix. densities * @param weights1, weights2 Weigths of the mixture densities diff --git a/src/utils/eigen_utils.h b/src/utils/eigen_utils.h index 46732da78..2307b7f6f 100644 --- a/src/utils/eigen_utils.h +++ b/src/utils/eigen_utils.h @@ -1,12 +1,13 @@ #ifndef BAYESMIX_SRC_UTILS_EIGEN_UTILS_H_ #define BAYESMIX_SRC_UTILS_EIGEN_UTILS_H_ -#include +#include #include -//! This file implements a few methods to manipulate groups of matrices, mainly -//! by joining different objects, as well as additional utilities for SPD -//! checking and grid creation. +//! @file eigen_utils.h +//! The `eigen_utils.h` file implements a few methods to manipulate groups of +//! matrices, mainly by joining different objects, as well as additional +//! utilities for SPD checking and grid creation. namespace bayesmix { //! Concatenates a vector of Eigen matrices along the rows diff --git a/src/utils/io_utils.cc b/src/utils/io_utils.cc index 2a953368b..0985f87ad 100644 --- a/src/utils/io_utils.cc +++ b/src/utils/io_utils.cc @@ -1,8 +1,8 @@ #include "io_utils.h" -#include #include #include +#include bool bayesmix::check_file_is_writeable(const std::string &filename) { std::ofstream ofstr; diff --git a/src/utils/io_utils.h b/src/utils/io_utils.h index d0b64de80..b9c4231a6 100644 --- a/src/utils/io_utils.h +++ b/src/utils/io_utils.h @@ -1,10 +1,11 @@ #ifndef BAYESMIX_UTILS_IO_UTILS_H_ #define BAYESMIX_UTILS_IO_UTILS_H_ -#include +#include -//! This file implements basic input-output utilities for Eigen matrices from -//! and to text files. +//! @file io_utils.h +//! 
The `io_utils.h` file implements basic input-output utilities for Eigen +//! matrices from and to text files. namespace bayesmix { //! Checks whether the given file is available for writing diff --git a/src/utils/proto_utils.cc b/src/utils/proto_utils.cc index 7c3d3203f..5078ed7b8 100644 --- a/src/utils/proto_utils.cc +++ b/src/utils/proto_utils.cc @@ -3,8 +3,8 @@ #include #include -#include #include +#include #include "matrix.pb.h" diff --git a/src/utils/proto_utils.h b/src/utils/proto_utils.h index c03199a23..cb8c3333d 100644 --- a/src/utils/proto_utils.h +++ b/src/utils/proto_utils.h @@ -1,15 +1,16 @@ #ifndef BAYESMIX_UTILS_PROTO_UTILS_H_ #define BAYESMIX_UTILS_PROTO_UTILS_H_ -#include +#include #include "matrix.pb.h" -//! This file implements a few useful functions to manipulate Protobuf objects. -//! For instance, this library implements its own version of vectors and -//! matrices, and the functions implemented here convert from these types to -//! the Eigen ones and viceversa. One can also read a Protobuf from a text -//! file. This is mostly useful for algorithm configuration files. +//! @file proto_utils.h +//! The `proto_utils.h` file implements a few useful functions to manipulate +//! Protobuf objects. For instance, this library implements its own version of +//! vectors and matrices, and the functions implemented here convert from these +//! types to the Eigen ones and viceversa. One can also read a Protobuf from a +//! text file. This is mostly useful for algorithm configuration files. namespace bayesmix { diff --git a/src/utils/rng.h b/src/utils/rng.h index 8fc6e6264..fc71a2f69 100644 --- a/src/utils/rng.h +++ b/src/utils/rng.h @@ -3,7 +3,8 @@ #include -//! Simple Random Number Generation class wrapper. +//! @file rng.h +//! The `rng.h` file defines a simple Random Number Generation class wrapper. //! This class wraps the C++ standard RNG object and allows the use of any RNG //! seed. 
It is implemented as a singleton, so that every object used in the //! library has access to the same exact RNG engine. diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index dd2a016b1..20f7806ae 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,9 +18,10 @@ FetchContent_MakeAvailable(googletest) add_executable(test_bayesmix $ write_proto.cc proto_utils.cc + likelihoods.cc + prior_models.cc hierarchies.cc lpdf.cc - priors.cc eigen_utils.cc distributions.cc semi_hdp.cc @@ -28,6 +29,7 @@ add_executable(test_bayesmix $ runtime.cc rng.cc logit_sb.cc + gradient.cc ) target_include_directories(test_bayesmix PUBLIC ${INCLUDE_PATHS}) diff --git a/test/collectors.cc b/test/collectors.cc index 38da70caa..7f03e96c5 100644 --- a/test/collectors.cc +++ b/test/collectors.cc @@ -1,6 +1,6 @@ #include -#include +#include #include #include "matrix.pb.h" diff --git a/test/distributions.cc b/test/distributions.cc index 38e4d38d5..3705a5833 100644 --- a/test/distributions.cc +++ b/test/distributions.cc @@ -2,10 +2,11 @@ #include -#include #include +#include #include +#include "src/hierarchies/likelihoods/states/includes.h" #include "src/utils/rng.h" TEST(mix_dist, 1) { @@ -88,7 +89,6 @@ TEST(student_t, optimized) { Eigen::VectorXd x = Eigen::VectorXd::Ones(5); double lpdf_stan = stan::math::multi_student_t_lpdf(x, df, mean, sigma); - // std::cout << "lpdf_stan: " << lpdf_stan << std::endl; Eigen::MatrixXd sigma_inv = stan::math::inverse_spd(sigma); Eigen::MatrixXd sigma_inv_chol = @@ -99,8 +99,6 @@ TEST(student_t, optimized) { double our_lpdf = bayesmix::multi_student_t_invscale_lpdf( x, df, mean, sigma_inv_chol, logdet); - // std::cout << "our_lpdf: " << our_lpdf << std::endl; - ASSERT_LE(std::abs(our_lpdf - lpdf_stan), 0.001); } @@ -152,17 +150,19 @@ TEST(mult_normal, lpdf_grid) { Eigen::MatrixXd tmp = Eigen::MatrixXd::Random(dim + 1, dim); Eigen::MatrixXd prec = tmp.transpose() * tmp + Eigen::MatrixXd::Identity(dim, dim); - Eigen::MatrixXd prec_chol = 
Eigen::LLT(prec).matrixU(); - Eigen::VectorXd diag = prec_chol.diagonal(); - double prec_logdet = 2 * log(diag.array()).sum(); + State::MultiLS state; + state.set_from_constrained(mean, prec); Eigen::VectorXd lpdfs = bayesmix::multi_normal_prec_lpdf_grid( - data, mean, prec_chol, prec_logdet); + data, state.mean, state.prec_chol, state.prec_logdet); for (int i = 0; i < 20; i++) { - double curr = bayesmix::multi_normal_prec_lpdf(data.row(i), mean, - prec_chol, prec_logdet); + double curr = bayesmix::multi_normal_prec_lpdf( + data.row(i), state.mean, state.prec_chol, state.prec_logdet); + double curr2 = stan::math::multi_normal_prec_lpdf(data.row(i), state.mean, + state.prec); ASSERT_DOUBLE_EQ(curr, lpdfs(i)); + ASSERT_DOUBLE_EQ(curr, curr2); } } diff --git a/test/eigen_utils.cc b/test/eigen_utils.cc index 5fd5b23b7..c02f11983 100644 --- a/test/eigen_utils.cc +++ b/test/eigen_utils.cc @@ -2,7 +2,7 @@ #include -#include +#include #include TEST(vstack, 1) { diff --git a/test/gradient.cc b/test/gradient.cc new file mode 100644 index 000000000..3ca4c1a13 --- /dev/null +++ b/test/gradient.cc @@ -0,0 +1,50 @@ +#include + +#include + +class fbase { + public: + virtual double lpdf(const Eigen::VectorXd& x) = 0; +}; + +class f1 : public fbase { + protected: + double y; + + public: + f1() = default; + f1(double y) : y(y) {} + + template + T lpdf(const Eigen::Matrix& x) const { + return 0.5 * x.squaredNorm() * y; + } + + double lpdf(const Eigen::Matrix& x) { + return this->lpdf(x); + } +}; + +template +struct target_lpdf { + F f; + + template + T operator()(const Eigen::Matrix& x) const { + return f.lpdf(x); + } +}; + +TEST(gradient, quadratic_function) { + Eigen::VectorXd out; + Eigen::VectorXd x(5); + x << 1.0, 2.0, 3.0, 4.0, 5.0; + target_lpdf target_function; + target_function.f = f1(5.0); + double y; + stan::math::gradient(target_function, x, y, out); + + for (int i = 0; i < 5; i++) { + ASSERT_DOUBLE_EQ(out(i), 5 * x(i)); + } +} diff --git a/test/hierarchies.cc 
b/test/hierarchies.cc index c48575c34..0a4a94c25 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -1,7 +1,7 @@ #include -#include #include +#include #include "algorithm_state.pb.h" #include "ls_state.pb.h" @@ -9,11 +9,12 @@ #include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" #include "src/hierarchies/nnw_hierarchy.h" +#include "src/hierarchies/nnxig_hierarchy.h" #include "src/includes.h" #include "src/utils/proto_utils.h" #include "src/utils/rng.h" -TEST(nnighierarchy, draw) { +TEST(nnig_hierarchy, draw) { auto hier = std::make_shared(); bayesmix::NNIGPrior prior; double mu0 = 5.0; @@ -39,7 +40,7 @@ TEST(nnighierarchy, draw) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -TEST(nnighierarchy, sample_given_data) { +TEST(nnig_hierarchy, sample_given_data) { auto hier = std::make_shared(); bayesmix::NNIGPrior prior; double mu0 = 5.0; @@ -70,7 +71,7 @@ TEST(nnighierarchy, sample_given_data) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -TEST(nnwhierarchy, draw) { +TEST(nnw_hierarchy, draw) { auto hier = std::make_shared(); bayesmix::NNWPrior prior; Eigen::Vector2d mu0; @@ -101,7 +102,7 @@ TEST(nnwhierarchy, draw) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -TEST(nnwhierarchy, sample_given_data) { +TEST(nnw_hierarchy, sample_given_data) { auto hier = std::make_shared(); bayesmix::NNWPrior prior; Eigen::Vector2d mu0; @@ -136,6 +137,33 @@ TEST(nnwhierarchy, sample_given_data) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } +TEST(nnw_hierarchy, no_unconstrained_lpdf) { + // Initialize hierarchy + auto hier = std::make_shared(); + bayesmix::NNWPrior prior; + Eigen::Vector2d mu0; + mu0 << 5.5, 5.5; + bayesmix::Vector mu0_proto; + bayesmix::to_proto(mu0, &mu0_proto); + double lambda0 = 0.2; + double nu0 = 5.0; + Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; + bayesmix::Matrix tau0_proto; + bayesmix::to_proto(tau0, 
&tau0_proto); + *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; + prior.mutable_fixed_values()->set_var_scaling(lambda0); + prior.mutable_fixed_values()->set_deg_free(nu0); + *prior.mutable_fixed_values()->mutable_scale() = tau0_proto; + hier->get_mutable_prior()->CopyFrom(prior); + hier->initialize(); + + // Check exeption handling in case unconstrained lpdfs are not implemented + auto state_uc = hier->get_state().get_unconstrained(); + EXPECT_ANY_THROW( + hier->get_likelihood()->cluster_lpdf_from_unconstrained(state_uc)); + EXPECT_ANY_THROW(hier->get_prior()->lpdf_from_unconstrained(state_uc)); +} + TEST(lin_reg_uni_hierarchy, state_read_write) { Eigen::Vector2d beta; beta << 2, -1; @@ -164,10 +192,8 @@ TEST(lin_reg_uni_hierarchy, state_read_write) { TEST(lin_reg_uni_hierarchy, misc) { // Build data - int n = 5; - int dim = 2; - Eigen::Vector2d beta_true; - beta_true << 10.0, 10.0; + int n = 5, dim = 2; + Eigen::Vector2d beta_true({10.0, 10.0}); Eigen::MatrixXd cov = Eigen::MatrixXd::Random(n, dim); // each in U[-1,1] double sigma2 = 1.0; Eigen::VectorXd data(n); @@ -175,25 +201,26 @@ TEST(lin_reg_uni_hierarchy, misc) { for (int i = 0; i < n; i++) { data(i) = stan::math::normal_rng(cov.row(i).dot(beta_true), sigma2, rng); } + // Initialize objects LinRegUniHierarchy hier; bayesmix::LinRegUniPrior prior; + // Create prior parameters Eigen::Vector2d beta0 = 0 * beta_true; - bayesmix::Vector beta0_proto; - bayesmix::to_proto(beta0, &beta0_proto); auto Lambda0 = Eigen::Matrix2d::Identity(); - bayesmix::Matrix Lambda0_proto; - bayesmix::to_proto(Lambda0, &Lambda0_proto); double a0 = 2.0; double b0 = 1.0; + // Initialize hierarchy - *prior.mutable_fixed_values()->mutable_mean() = beta0_proto; - *prior.mutable_fixed_values()->mutable_var_scaling() = Lambda0_proto; + bayesmix::to_proto(beta0, prior.mutable_fixed_values()->mutable_mean()); + bayesmix::to_proto(Lambda0, + prior.mutable_fixed_values()->mutable_var_scaling()); 
prior.mutable_fixed_values()->set_shape(a0); prior.mutable_fixed_values()->set_scale(b0); hier.get_mutable_prior()->CopyFrom(prior); hier.initialize(); + // Extract hypers for reading test bayesmix::AlgorithmState::HierarchyHypers out; hier.write_hypers_to_proto(&out); @@ -202,10 +229,12 @@ TEST(lin_reg_uni_hierarchy, misc) { bayesmix::to_eigen(out.lin_reg_uni_state().var_scaling())); ASSERT_EQ(a0, out.lin_reg_uni_state().shape()); ASSERT_EQ(b0, out.lin_reg_uni_state().scale()); + // Add data for (int i = 0; i < n; i++) { hier.add_datum(i, data.row(i), false, cov.row(i)); } + // Check summary statistics // for (int i = 0; i < dim; i++) { // for (int j = 0; j < dim; j++) { @@ -214,6 +243,7 @@ TEST(lin_reg_uni_hierarchy, misc) { // } // ASSERT_DOUBLE_EQ(hier.get_mixed_prod()(i), (cov.transpose() * data)(i)); // } + // Compute and check posterior values hier.sample_full_cond(); auto state = hier.get_state(); @@ -222,25 +252,13 @@ TEST(lin_reg_uni_hierarchy, misc) { } } -TEST(fahierarchy, draw) { - auto hier = std::make_shared(); - bayesmix::FAPrior prior; - Eigen::VectorXd mutilde(4); - mutilde << 3.0, 3.0, 4.0, 1.0; - bayesmix::Vector mutilde_proto; - bayesmix::to_proto(mutilde, &mutilde_proto); - int q = 2; - double phi = 1.0; - double alpha0 = 5.0; - Eigen::VectorXd beta(4); - beta << 3.0, 3.0, 2.0, 2.1; - bayesmix::Vector beta_proto; - bayesmix::to_proto(beta, &beta_proto); - *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; - prior.mutable_fixed_values()->set_phi(phi); - prior.mutable_fixed_values()->set_alpha0(alpha0); - prior.mutable_fixed_values()->set_q(q); - *prior.mutable_fixed_values()->mutable_beta() = beta_proto; +TEST(nnxig_hierarchy, draw) { + auto hier = std::make_shared(); + bayesmix::NNxIGPrior prior; + prior.mutable_fixed_values()->set_mean(5.0); + prior.mutable_fixed_values()->set_var(1.2); + prior.mutable_fixed_values()->set_shape(2.0); + prior.mutable_fixed_values()->set_scale(2.0); hier->get_mutable_prior()->CopyFrom(prior); 
hier->initialize(); @@ -256,22 +274,47 @@ TEST(fahierarchy, draw) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -TEST(fahierarchy, draw_auto) { +TEST(nnxig_hierarchy, sample_given_data) { + auto hier = std::make_shared(); + bayesmix::NNxIGPrior prior; + prior.mutable_fixed_values()->set_mean(5.0); + prior.mutable_fixed_values()->set_var(1.2); + prior.mutable_fixed_values()->set_shape(2.0); + prior.mutable_fixed_values()->set_scale(2.0); + hier->get_mutable_prior()->CopyFrom(prior); + + hier->initialize(); + + Eigen::VectorXd datum(1); + datum << 4.5; + + auto hier2 = hier->clone(); + hier2->add_datum(0, datum, false); + hier2->sample_full_cond(); + + bayesmix::AlgorithmState out; + bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); + bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); + hier->write_state_to_proto(clusval); + hier2->write_state_to_proto(clusval2); + + ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +} + +TEST(fa_hierarchy, draw) { auto hier = std::make_shared(); bayesmix::FAPrior prior; - Eigen::VectorXd mutilde(0); + Eigen::VectorXd mutilde(4); + mutilde << 3.0, 3.0, 4.0, 1.0; bayesmix::Vector mutilde_proto; bayesmix::to_proto(mutilde, &mutilde_proto); int q = 2; double phi = 1.0; double alpha0 = 5.0; - Eigen::VectorXd beta(0); + Eigen::VectorXd beta(4); + beta << 3.0, 3.0, 2.0, 2.1; bayesmix::Vector beta_proto; bayesmix::to_proto(beta, &beta_proto); - Eigen::MatrixXd dataset(5, 5); - dataset << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, - 20, 1, 5, 7, 8, 9; - hier->set_dataset(&dataset); *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; prior.mutable_fixed_values()->set_phi(phi); prior.mutable_fixed_values()->set_alpha0(alpha0); @@ -289,11 +332,10 @@ TEST(fahierarchy, draw_auto) { hier->write_state_to_proto(clusval); hier2->write_state_to_proto(clusval2); - ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()) - 
<< clusval->DebugString() << clusval2->DebugString(); + ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -TEST(fahierarchy, sample_given_data) { +TEST(fa_hierarchy, sample_given_data) { auto hier = std::make_shared(); bayesmix::FAPrior prior; Eigen::VectorXd mutilde(4); diff --git a/test/likelihoods.cc b/test/likelihoods.cc new file mode 100644 index 000000000..4cb2a7d5e --- /dev/null +++ b/test/likelihoods.cc @@ -0,0 +1,391 @@ +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "ls_state.pb.h" +#include "src/hierarchies/likelihoods/laplace_likelihood.h" +#include "src/hierarchies/likelihoods/multi_norm_likelihood.h" +#include "src/hierarchies/likelihoods/uni_lin_reg_likelihood.h" +#include "src/hierarchies/likelihoods/uni_norm_likelihood.h" +#include "src/utils/proto_utils.h" +#include "src/utils/rng.h" + +TEST(uni_norm_likelihood, set_get_state) { + // Instance + auto like = std::make_shared(); + + // Prepare buffers + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState set_state_; + bayesmix::AlgorithmState::ClusterState got_state_; + + // Prepare state + state_.set_mean(5.23); + state_.set_var(1.02); + set_state_.mutable_uni_ls_state()->CopyFrom(state_); + + // Set and get the state + like->set_state_from_proto(set_state_); + like->write_state_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(uni_norm_likelihood, add_remove_data) { + // Instance + auto like = std::make_shared(); + + // Add new datum to likelihood + Eigen::VectorXd datum(1); + datum << 5.0; + like->add_datum(0, datum); + + // Check if cardinality is augmented + ASSERT_EQ(like->get_card(), 1); + + // Remove datum from likelihood + like->remove_datum(0, datum); + + // Check if cardinality is reduced + ASSERT_EQ(like->get_card(), 0); +} + +TEST(uni_norm_likelihood, eval_lpdf) { + // Instance + auto like = std::make_shared(); + + // Set state from 
proto + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + state_.set_mean(5); + state_.set_var(1); + clust_state_.mutable_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Add new datum to likelihood + Eigen::VectorXd data(3); + data << 4.5, 5.1, 2.5; + + // Compute lpdf on this grid of points + auto evals = like->lpdf_grid(data); + auto like_copy = like->clone(); + auto evals_copy = like_copy->lpdf_grid(data); + + // Check if they coincides + ASSERT_EQ(evals, evals_copy); +} + +TEST(uni_norm_likelihood, eval_lpdf_unconstrained) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + double mean = 5; + double var = 1; + state_.set_mean(mean); + state_.set_var(var); + Eigen::VectorXd unconstrained_params(2); + unconstrained_params << mean, std::log(var); + clust_state_.mutable_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Add new datum to likelihood + Eigen::VectorXd data(3); + data << 4.5, 5.1, 2.5; + double lpdf = 0.0; + for (int i = 0; i < data.size(); ++i) { + like->add_datum(i, data.row(i)); + lpdf += like->lpdf(data.row(i)); + } + + double clus_lpdf = + like->cluster_lpdf_from_unconstrained(unconstrained_params); + ASSERT_DOUBLE_EQ(lpdf, clus_lpdf); + + unconstrained_params(0) = 4.0; + clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); + ASSERT_TRUE(std::abs(clus_lpdf - lpdf) > 1e-5); +} + +TEST(multi_ls_state, set_unconstrained) { + auto& rng = bayesmix::Rng::Instance().get(); + + State::MultiLS state; + auto mean = Eigen::VectorXd::Zero(5); + auto prec = + stan::math::wishart_rng(10, Eigen::MatrixXd::Identity(5, 5), rng); + state.mean = mean; + state.prec = prec; + Eigen::VectorXd unconstrained_state = state.get_unconstrained(); + + State::MultiLS state2; + state2.set_from_unconstrained(unconstrained_state); + 
ASSERT_TRUE((state.mean - state2.mean).squaredNorm() < 1e-5); + ASSERT_TRUE((state.prec - state2.prec).squaredNorm() < 1e-5); +} + +TEST(multi_norm_likelihood, set_get_state) { + // Instance + auto like = std::make_shared(); + + // Prepare buffers + bayesmix::MultiLSState state_; + bayesmix::AlgorithmState::ClusterState set_state_; + bayesmix::AlgorithmState::ClusterState got_state_; + + // Prepare state + Eigen::Vector2d mu = {5.5, 5.5}; // mu << 5.5, 5.5; + Eigen::Matrix2d prec = Eigen::Matrix2d::Identity(); + bayesmix::Vector mean_proto; + bayesmix::Matrix prec_proto; + bayesmix::to_proto(mu, &mean_proto); + bayesmix::to_proto(prec, &prec_proto); + set_state_.mutable_multi_ls_state()->mutable_mean()->CopyFrom(mean_proto); + set_state_.mutable_multi_ls_state()->mutable_prec()->CopyFrom(prec_proto); + set_state_.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom( + prec_proto); + + // Set and get the state + like->set_state_from_proto(set_state_); + like->write_state_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(multi_norm_likelihood, add_remove_data) { + // Instance + auto like = std::make_shared(); + + // Add new datum to likelihood + Eigen::RowVectorXd datum(2); + datum << 5.5, 5.5; + like->add_datum(0, datum); + + // Check if cardinality is augmented + ASSERT_EQ(like->get_card(), 1); + + // Remove datum from likelihood + like->remove_datum(0, datum); + + // Check if cardinality is reduced + ASSERT_EQ(like->get_card(), 0); +} + +TEST(multi_norm_likelihood, eval_lpdf) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::AlgorithmState::ClusterState clust_state_; + Eigen::Vector2d mu = {5.5, 5.5}; // mu << 5.5, 5.5; + Eigen::Matrix2d prec = Eigen::Matrix2d::Identity(); + bayesmix::Vector mean_proto; + bayesmix::Matrix prec_proto; + bayesmix::to_proto(mu, &mean_proto); + bayesmix::to_proto(prec, &prec_proto); + 
clust_state_.mutable_multi_ls_state()->mutable_mean()->CopyFrom(mean_proto); + clust_state_.mutable_multi_ls_state()->mutable_prec()->CopyFrom(prec_proto); + clust_state_.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom( + prec_proto); + like->set_state_from_proto(clust_state_); + + // Data matrix on which evaluate the likelihood + Eigen::MatrixXd data(3, 2); + data.row(0) << 4.5, 4.5; + data.row(1) << 5.1, 5.1; + data.row(2) << 2.5, 2.5; + + // Compute lpdf on this grid of points + auto evals = like->lpdf_grid(data); + auto like_copy = like->clone(); + auto evals_copy = like_copy->lpdf_grid(data); + + // Check if they coincides + ASSERT_EQ(evals, evals_copy); +} + +TEST(uni_lin_reg_likelihood, set_get_state) { + // Instance + auto like = std::make_shared(); + + // Prepare buffers + bayesmix::LinRegUniLSState state_; + bayesmix::AlgorithmState::ClusterState set_state_; + bayesmix::AlgorithmState::ClusterState got_state_; + + // Prepare state + Eigen::Vector3d reg_coeffs; + reg_coeffs << 2.25, 0.22, -7.1; + bayesmix::Vector reg_coeffs_proto; + bayesmix::to_proto(reg_coeffs, ®_coeffs_proto); + state_.mutable_regression_coeffs()->CopyFrom(reg_coeffs_proto); + state_.set_var(1.02); + set_state_.mutable_lin_reg_uni_ls_state()->CopyFrom(state_); + + // Set and get the state + like->set_state_from_proto(set_state_); + like->write_state_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(uni_lin_reg_likelihood, add_remove_data) { + // Instance + auto like = std::make_shared(); + + // Add new datum to likelihood + Eigen::VectorXd datum(1); + datum << 5.0; + like->add_datum(0, datum); + + // Check if cardinality is augmented + ASSERT_EQ(like->get_card(), 1); + + // Remove datum from likelihood + like->remove_datum(0, datum); + + // Check if cardinality is reduced + ASSERT_EQ(like->get_card(), 0); +} + +TEST(uni_lin_reg_likelihood, eval_lpdf) { + // Instance + auto like = std::make_shared(); 
+ + // Set state from proto + bayesmix::LinRegUniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + Eigen::Vector3d reg_coeffs; + reg_coeffs << 2.25, 0.22, -7.1; + bayesmix::Vector reg_coeffs_proto; + bayesmix::to_proto(reg_coeffs, ®_coeffs_proto); + state_.mutable_regression_coeffs()->CopyFrom(reg_coeffs_proto); + state_.set_var(1.02); + clust_state_.mutable_lin_reg_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Generate data + Eigen::Vector3d data; + data << 4.5, 5.1, 2.5; + + // Generate random covariate matrix + Eigen::MatrixXd cov = + Eigen::MatrixXd::Random(data.size(), reg_coeffs.size()); + + // Compute lpdf on this grid of points + auto evals = like->lpdf_grid(data, cov); + auto like_copy = like->clone(); + auto evals_copy = like_copy->lpdf_grid(data, cov); + + // Check if they coincides + ASSERT_EQ(evals, evals_copy); +} + +TEST(laplace_likelihood, set_get_state) { + // Instance + auto like = std::make_shared(); + + // Prepare buffers + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState set_state_; + bayesmix::AlgorithmState::ClusterState got_state_; + + // Prepare state + state_.set_mean(5.23); + state_.set_var(1.02); + set_state_.mutable_uni_ls_state()->CopyFrom(state_); + + // Set and get the state + like->set_state_from_proto(set_state_); + like->write_state_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(laplace_likelihood, add_remove_data) { + // Instance + auto like = std::make_shared(); + + // Add new datum to likelihood + Eigen::VectorXd datum(1); + datum << 5.0; + like->add_datum(0, datum); + + // Check if cardinality is augmented + ASSERT_EQ(like->get_card(), 1); + + // Remove datum from likelihood + like->remove_datum(0, datum); + + // Check if cardinality is reduced + ASSERT_EQ(like->get_card(), 0); +} + +TEST(laplace_likelihood, eval_lpdf) { + // Instance + auto like = 
std::make_shared(); + + // Set state from proto + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + state_.set_mean(5); + state_.set_var(1); + clust_state_.mutable_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Add new datum to likelihood + Eigen::VectorXd data(3); + data << 4.5, 5.1, 2.5; + + // Compute lpdf on this grid of points + auto evals = like->lpdf_grid(data); + auto like_copy = like->clone(); + auto evals_copy = like_copy->lpdf_grid(data); + + // Check if they coincides + ASSERT_EQ(evals, evals_copy); +} + +TEST(laplace_likelihood, eval_lpdf_unconstrained) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + double mean = 5; + double var = 1; + state_.set_mean(mean); + state_.set_var(var); + Eigen::VectorXd unconstrained_params(2); + unconstrained_params << mean, std::log(var); + clust_state_.mutable_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Add new data to likelihood + Eigen::MatrixXd data(3, 1); + data << 4.5, 5.1, 2.5; + double lpdf = 0.0; + for (int i = 0; i < data.size(); ++i) { + like->add_datum(i, data.row(i)); + lpdf += like->lpdf(data.row(i)); + } + + like->set_dataset(&data); + double clus_lpdf = + like->cluster_lpdf_from_unconstrained(unconstrained_params); + ASSERT_DOUBLE_EQ(lpdf, clus_lpdf); + + unconstrained_params(0) = 3.0; + clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); + ASSERT_TRUE(std::abs(clus_lpdf - lpdf) > 1e-5); +} diff --git a/test/logit_sb.cc b/test/logit_sb.cc index 0033e6bf6..85405dc00 100644 --- a/test/logit_sb.cc +++ b/test/logit_sb.cc @@ -1,7 +1,7 @@ #include -#include #include +#include #include #include "src/hierarchies/abstract_hierarchy.h" diff --git a/test/lpdf.cc b/test/lpdf.cc index 5a97220a0..fe1cde610 100644 --- a/test/lpdf.cc +++ b/test/lpdf.cc @@ -1,13 +1,12 @@ 
#include -#include #include // lgamma, lmgamma #include +#include #include "algorithm_state.pb.h" #include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" -#include "src/hierarchies/nnw_hierarchy.h" #include "src/utils/proto_utils.h" TEST(lpdf, nnig) { @@ -54,89 +53,6 @@ TEST(lpdf, nnig) { ASSERT_DOUBLE_EQ(sum, marg); } -// TEST(lpdf, nnw) { // TODO -// using namespace stan::math; -// NNWHierarchy hier; -// bayesmix::NNWPrior hier_prior; -// Eigen::Vector2d mu0; mu0 << 5.5, 5.5; -// bayesmix::Vector mu0_proto; -// bayesmix::to_proto(mu0, &mu0_proto); -// double lambda0 = 0.2; -// double nu0 = 5.0; -// Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; -// bayesmix::Matrix tau0_proto; -// bayesmix::to_proto(tau0, &tau0_proto); -// *hier_prior.mutable_fixed_values()->mutable_mean() = mu0_proto; -// hier_prior.mutable_fixed_values()->set_var_scaling(lambda0); -// hier_prior.mutable_fixed_values()->set_deg_free(nu0); -// *hier_prior.mutable_fixed_values()->mutable_scale() = tau0_proto; -// hier.set_prior(hier_prior); -// hier.initialize(); -// -// Eigen::VectorXd mu = mu0; -// Eigen::MatrixXd tau = lambda0 * Eigen::Matrix2d::Identity(); -// -// Eigen::RowVectorXd datum(2); -// datum << 4.5, 4.5; -// -// // Compute prior parameters -// Eigen::MatrixXd tau_pr = lambda0 * tau0; -// -// // Compute posterior parameters -// double mu_n = (lambda0 * mu0 + datum(0)) / (lambda0 + 1); -// double alpha_n = alpha0 + 0.5; -// double lambda_n = lambda0 + 1; -// double nu_n = nu0 + 0.5; -// Eigen::VectorXd mu_n = -// (lambda0 * mu0 + datum.transpose()) / (lambda0 + 1); -// Eigen::MatrixXd tau_temp = -// stan::math::inverse_spd(tau0) + (0.5 * lambda0 / (lambda0 + 1)) * -// (datum.transpose() - mu0) * -// (datum - mu0.transpose()); -// Eigen::MatrixXd tau_n = stan::math::inverse_spd(tau_temp); -// Eigen::MatrixXd tau_post = lambda_n * tau_n; -// -// // Compute pieces -// double prior1 = stan::math::wishart_lpdf(tau, nu0, tau0); -// double 
prior2 = stan::math::multi_normal_prec_lpdf(mu, mu0, tau_pr); -// double prior = prior1 + prior2; -// double like = hier.get_like_lpdf(datum); -// double post1 = stan::math::wishart_lpdf(tau, nu_n, tau_post); -// double post2 = stan::math::multi_normal_prec_lpdf(mu, mu0, tau_post); -// double post = post1 + post2; -// -// // Bayes: logmarg(x) = logprior(phi) + loglik(x|phi) - logpost(phi|x) -// double sum = prior + like - post; -// double marg = hier.prior_pred_lpdf(false, datum); -// -// // Compute logdet's -// Eigen::MatrixXd tauchol0 = -// Eigen::LLT(tau0).matrixL().transpose(); -// double logdet0 = 2 * log(tauchol0.diagonal().array()).sum(); -// Eigen::MatrixXd tauchol_n = -// Eigen::LLT(tau_n).matrixL().transpose(); -// double logdet_n = 2 * log(tauchol_n.diagonal().array()).sum(); -// -// // lmgamma(dim, x) -// int dim = 2; -// double marg_murphy = lmgamma(dim, 0.5 * nu_n) + 0.5 * nu_n * logdet_n + -// 0.5 * dim * log(lambda0) + dim * NEG_LOG_SQRT_TWO_PI -// - lmgamma(dim, 0.5 * nu0) - 0.5 * nu0 * logdet0 - 0.5 -// * dim * log(lambda_n); -// -// // std::cout << "prior1=" << prior1 << std::endl; -// // std::cout << "prior2=" << prior2 << std::endl; -// // std::cout << "prior =" << prior << std::endl; -// // std::cout << "like =" << like << std::endl; -// // std::cout << "post1 =" << post1 << std::endl; -// // std::cout << "post2 =" << post2 << std::endl; -// // std::cout << "post =" << post << std::endl; -// std::cout << "sum =" << sum << std::endl; -// std::cout << "marg =" << marg << std::endl; -// std::cout << "murphy=" << marg_murphy << std::endl; -// ASSERT_DOUBLE_EQ(marg, marg_murphy); -// } - TEST(lpdf, lin_reg_uni) { // Create hierarchy objects LinRegUniHierarchy hier; diff --git a/test/prior_models.cc b/test/prior_models.cc new file mode 100644 index 000000000..be039e4f8 --- /dev/null +++ b/test/prior_models.cc @@ -0,0 +1,482 @@ +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "hierarchy_prior.pb.h" +#include 
"src/hierarchies/priors/mnig_prior_model.h" +#include "src/hierarchies/priors/nig_prior_model.h" +#include "src/hierarchies/priors/nw_prior_model.h" +#include "src/hierarchies/priors/nxig_prior_model.h" +#include "src/utils/proto_utils.h" + +TEST(nig_prior_model, set_get_hypers) { + // Instance + auto prior = std::make_shared(); + + // Prepare buffers + bayesmix::NIGDistribution hypers_; + bayesmix::AlgorithmState::HierarchyHypers set_state_; + bayesmix::AlgorithmState::HierarchyHypers got_state_; + + // Prepare hypers + hypers_.set_mean(5.0); + hypers_.set_var_scaling(0.1); + hypers_.set_shape(4.0); + hypers_.set_scale(3.0); + set_state_.mutable_nnig_state()->CopyFrom(hypers_); + + // Set and get hypers + prior->set_hypers_from_proto(set_state_); + prior->write_hypers_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(nig_prior_model, fixed_values_prior) { + // Prepare buffers + bayesmix::NNIGPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + std::vector> prior_models; + std::vector states; + + // Set fixed value prior + prior.mutable_fixed_values()->set_mean(5.0); + prior.mutable_fixed_values()->set_var_scaling(0.1); + prior.mutable_fixed_values()->set_shape(2.0); + prior.mutable_fixed_values()->set_scale(2.0); + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Check equality before update + prior_models.push_back(prior_model); + for (size_t i = 1; i < 4; i++) { + prior_models.push_back(prior_model->clone()); + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnig_state().DebugString()); + } + + // Check equality after update + prior_models[0]->update_hypers(states); + prior_models[0]->write_hypers_to_proto(&prior_out); + for (size_t i = 1; i < 4; i++) { + 
prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnig_state().DebugString()); + } +} + +TEST(nig_prior_model, normal_mean_prior) { + // Prepare buffers + bayesmix::NNIGPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + + // Set Normal prior on the mean + double mu00 = 0.5; + prior.mutable_normal_mean_prior()->mutable_mean_prior()->set_mean(mu00); + prior.mutable_normal_mean_prior()->mutable_mean_prior()->set_var(1.02); + prior.mutable_normal_mean_prior()->set_var_scaling(0.1); + prior.mutable_normal_mean_prior()->set_shape(2.0); + prior.mutable_normal_mean_prior()->set_scale(2.0); + + // Prepare some fictional states + std::vector states(4); + for (int i = 0; i < states.size(); i++) { + double mean = 9.0 + i; + states[i].mutable_uni_ls_state()->set_mean(mean); + states[i].mutable_uni_ls_state()->set_var(1.0); + } + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Update hypers in light of current states + prior_model->update_hypers(states); + prior_model->write_hypers_to_proto(&prior_out); + double mean_out = prior_out.nnig_state().mean(); + + // Check + ASSERT_GT(mean_out, mu00); +} + +TEST(nig_prior_model, sample) { + // Instance + auto prior = std::make_shared(); + + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + hypers_proto.mutable_nnig_state()->set_mean(5.0); + hypers_proto.mutable_nnig_state()->set_var_scaling(0.1); + hypers_proto.mutable_nnig_state()->set_shape(4.0); + hypers_proto.mutable_nnig_state()->set_scale(3.0); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + auto state1 = prior->sample(); + auto state2 = prior->sample(); + + // Check if they coincides + ASSERT_TRUE(state1.get_as_proto().DebugString() != + state2.get_as_proto().DebugString()); +} + +TEST(nxig_prior_model, 
set_get_hypers) { + // Instance + auto prior = std::make_shared(); + + // Prepare buffers + bayesmix::NxIGDistribution hypers_; + bayesmix::AlgorithmState::HierarchyHypers set_state_; + bayesmix::AlgorithmState::HierarchyHypers got_state_; + + // Prepare hypers + hypers_.set_mean(5.0); + hypers_.set_var(1.2); + hypers_.set_shape(4.0); + hypers_.set_scale(3.0); + set_state_.mutable_nnxig_state()->CopyFrom(hypers_); + + // Set and get hypers + prior->set_hypers_from_proto(set_state_); + prior->write_hypers_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(nxig_prior_model, fixed_values_prior) { + // Prepare buffers + bayesmix::NNxIGPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + std::vector> prior_models; + std::vector states; + + // Set fixed value prior + prior.mutable_fixed_values()->set_mean(5.0); + prior.mutable_fixed_values()->set_var(1.2); + prior.mutable_fixed_values()->set_shape(2.0); + prior.mutable_fixed_values()->set_scale(2.0); + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Check equality before update + prior_models.push_back(prior_model); + for (size_t i = 1; i < 4; i++) { + prior_models.push_back(prior_model->clone()); + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnxig_state().DebugString()); + } + + // Check equality after update + prior_models[0]->update_hypers(states); + prior_models[0]->write_hypers_to_proto(&prior_out); + for (size_t i = 1; i < 4; i++) { + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnxig_state().DebugString()); + } +} + +TEST(nxig_prior_model, sample) { + // Instance + auto prior = std::make_shared(); + // bool use_post_hypers = true; + + // Define prior hypers + 
bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + hypers_proto.mutable_nnxig_state()->set_mean(5.0); + hypers_proto.mutable_nnxig_state()->set_var(1.2); + hypers_proto.mutable_nnxig_state()->set_shape(4.0); + hypers_proto.mutable_nnxig_state()->set_scale(3.0); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + auto state1 = prior->sample(); + auto state2 = prior->sample(); + + // Check if they coincides + ASSERT_TRUE(state1.get_as_proto().DebugString() != + state2.get_as_proto().DebugString()); +} + +TEST(nig_prior_model, unconstrained_lpdf) { + // Instance + auto prior = std::make_shared(); + + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + hypers_proto.mutable_nnig_state()->set_mean(5.0); + hypers_proto.mutable_nnig_state()->set_var_scaling(0.1); + hypers_proto.mutable_nnig_state()->set_shape(4.0); + hypers_proto.mutable_nnig_state()->set_scale(3.0); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + double mean = 0.0; + double var = 5.0; + + bayesmix::AlgorithmState::ClusterState state; + state.mutable_uni_ls_state()->set_mean(mean); + state.mutable_uni_ls_state()->set_var(var); + Eigen::VectorXd unconstrained_params(2); + unconstrained_params << mean, std::log(var); + + ASSERT_DOUBLE_EQ( + prior->lpdf(state) + std::log(std::exp(unconstrained_params(1))), + prior->lpdf_from_unconstrained(unconstrained_params)); +} + +TEST(nw_prior_model, set_get_hypers) { + // Instance + auto prior = std::make_shared(); + + // Prepare buffers + bayesmix::NWDistribution hypers_; + bayesmix::AlgorithmState::HierarchyHypers set_state_; + bayesmix::AlgorithmState::HierarchyHypers got_state_; + + bayesmix::Vector mean_proto; + bayesmix::Matrix scale_proto; + bayesmix::to_proto(Eigen::Vector2d({5.5, 5.5}), &mean_proto); + bayesmix::to_proto(Eigen::Matrix2d::Identity(), &scale_proto); + + // Prepare hypers + 
hypers_.mutable_mean()->CopyFrom(mean_proto); + hypers_.set_deg_free(4); + hypers_.set_var_scaling(0.1); + hypers_.mutable_scale()->CopyFrom(scale_proto); + set_state_.mutable_nnw_state()->CopyFrom(hypers_); + + // Set and get hypers + prior->set_hypers_from_proto(set_state_); + prior->write_hypers_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(nw_prior_model, fixed_values_prior) { + // Prepare buffers + bayesmix::NNWPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + std::vector> prior_models; + std::vector states; + + // Set fixed value prior + bayesmix::Vector mean_proto; + bayesmix::Matrix scale_proto; + bayesmix::to_proto(Eigen::Vector2d({5.5, 5.5}), &mean_proto); + bayesmix::to_proto(Eigen::Matrix2d::Identity(), &scale_proto); + prior.mutable_fixed_values()->mutable_mean()->CopyFrom(mean_proto); + prior.mutable_fixed_values()->set_var_scaling(0.1); + prior.mutable_fixed_values()->set_deg_free(4); + prior.mutable_fixed_values()->mutable_scale()->CopyFrom(scale_proto); + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Check equality before update + prior_models.push_back(prior_model); + for (size_t i = 1; i < 4; i++) { + prior_models.push_back(prior_model->clone()); + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnw_state().DebugString()); + } + + // Check equality after update + prior_models[0]->update_hypers(states); + prior_models[0]->write_hypers_to_proto(&prior_out); + for (size_t i = 1; i < 4; i++) { + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnw_state().DebugString()); + } +} + +TEST(nw_prior_model, normal_mean_prior) { + // Prepare buffers + bayesmix::NNWPrior prior; + bayesmix::AlgorithmState::HierarchyHypers 
prior_out; + + // Set Normal prior on the mean + Eigen::Vector2d mu00 = Eigen::Vector2d::Zero(); + Eigen::Matrix2d Sigma00 = Eigen::Matrix2d::Identity(); + bayesmix::Vector mu00_proto; + bayesmix::Matrix Sigma00_proto, scale_proto; + bayesmix::to_proto(mu00, &mu00_proto); + bayesmix::to_proto(Sigma00, &Sigma00_proto); + bayesmix::to_proto(Eigen::Matrix2d::Identity(), &scale_proto); + prior.mutable_normal_mean_prior() + ->mutable_mean_prior() + ->mutable_mean() + ->CopyFrom(mu00_proto); + prior.mutable_normal_mean_prior() + ->mutable_mean_prior() + ->mutable_var() + ->CopyFrom(Sigma00_proto); + prior.mutable_normal_mean_prior()->set_var_scaling(0.1); + prior.mutable_normal_mean_prior()->set_deg_free(4); + prior.mutable_normal_mean_prior()->mutable_scale()->CopyFrom(scale_proto); + + // Prepare some fictional states + std::vector states(4); + for (int i = 0; i < states.size(); i++) { + Eigen::Vector2d mean = (9.0 + i) * Eigen::Vector2d::Ones(); + bayesmix::Vector tmp; + bayesmix::to_proto(mean, &tmp); + states[i].mutable_multi_ls_state()->mutable_mean()->CopyFrom(tmp); + states[i].mutable_multi_ls_state()->mutable_prec()->CopyFrom(scale_proto); + states[i].mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom( + scale_proto); + } + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Update hypers in light of current states + prior_model->update_hypers(states); + prior_model->write_hypers_to_proto(&prior_out); + Eigen::Vector2d mean_out = bayesmix::to_eigen(prior_out.nnw_state().mean()); + + // Check + for (size_t i = 0; i < mu00.size(); i++) { + ASSERT_GT(mean_out(i), mu00(i)); + } +} + +TEST(nw_prior_model, sample) { + // Instance + auto prior = std::make_shared(); + // bool use_post_hypers = true; + + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + bayesmix::Vector mean; + bayesmix::Matrix scale; + 
bayesmix::to_proto(Eigen::Vector2d({5.2, 5.2}), &mean); + bayesmix::to_proto(Eigen::Matrix2d::Identity(), &scale); + hypers_proto.mutable_nnw_state()->mutable_mean()->CopyFrom(mean); + hypers_proto.mutable_nnw_state()->set_var_scaling(0.1); + hypers_proto.mutable_nnw_state()->set_deg_free(4); + hypers_proto.mutable_nnw_state()->mutable_scale()->CopyFrom(scale); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + auto state1 = prior->sample(); + auto state2 = prior->sample(); + + // Check if they coincides + ASSERT_TRUE(state1.get_as_proto().DebugString() != + state2.get_as_proto().DebugString()); +} + +TEST(mnig_prior_model, set_get_hypers) { + // Instance + auto prior = std::make_shared(); + + // Prepare buffers + bayesmix::MultiNormalIGDistribution hypers_; + bayesmix::AlgorithmState::HierarchyHypers set_state_; + bayesmix::AlgorithmState::HierarchyHypers got_state_; + + // Prepare hypers + Eigen::Vector2d mean({2.0, 2.0}); + bayesmix::to_proto(mean, hypers_.mutable_mean()); + Eigen::Matrix2d var_scaling = Eigen::Matrix2d::Identity(); + bayesmix::to_proto(var_scaling, hypers_.mutable_var_scaling()); + hypers_.set_shape(4.0); + hypers_.set_scale(3.0); + set_state_.mutable_lin_reg_uni_state()->CopyFrom(hypers_); + + // Set and get hypers + prior->set_hypers_from_proto(set_state_); + prior->write_hypers_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(mnig_prior_model, fixed_values_prior) { + // Prepare buffers + bayesmix::LinRegUniPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + std::vector> prior_models; + std::vector states; + + // Set fixed value prior + Eigen::Vector2d mean({2.0, 2.0}); + bayesmix::to_proto(mean, prior.mutable_fixed_values()->mutable_mean()); + Eigen::Matrix2d var_scaling = Eigen::Matrix2d::Identity(); + bayesmix::to_proto(var_scaling, + prior.mutable_fixed_values()->mutable_var_scaling()); + 
prior.mutable_fixed_values()->set_shape(4.0); + prior.mutable_fixed_values()->set_scale(3.0); + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Check equality before update + prior_models.push_back(prior_model); + for (size_t i = 1; i < 4; i++) { + prior_models.push_back(prior_model->clone()); + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.lin_reg_uni_state().DebugString()); + } + + // Check equality after update + prior_models[0]->update_hypers(states); + prior_models[0]->write_hypers_to_proto(&prior_out); + for (size_t i = 1; i < 4; i++) { + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.lin_reg_uni_state().DebugString()); + } +} + +TEST(mnig_prior_model, sample) { + // Instance + auto prior = std::make_shared(); + // bool use_post_hypers = true; + + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + Eigen::Vector2d mean({5.0, 5.0}); + bayesmix::to_proto(mean, + hypers_proto.mutable_lin_reg_uni_state()->mutable_mean()); + Eigen::Matrix2d var_scaling = Eigen::Matrix2d::Identity(); + bayesmix::to_proto( + var_scaling, + hypers_proto.mutable_lin_reg_uni_state()->mutable_var_scaling()); + hypers_proto.mutable_lin_reg_uni_state()->set_shape(4.0); + hypers_proto.mutable_lin_reg_uni_state()->set_scale(3.0); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + auto state1 = prior->sample(); + auto state2 = prior->sample(); + + // Check if they coincides + ASSERT_TRUE(state1.get_as_proto().DebugString() != + state2.get_as_proto().DebugString()); +} diff --git a/test/priors.cc b/test/priors.cc deleted file mode 100644 index 6ad84a843..000000000 --- a/test/priors.cc +++ /dev/null @@ -1,126 +0,0 @@ -#include -#include - -#include - -#include 
"algorithm_state.pb.h" -#include "src/hierarchies/nnig_hierarchy.h" -#include "src/hierarchies/nnw_hierarchy.h" -#include "src/mixings/dirichlet_mixing.h" -#include "src/utils/proto_utils.h" - -TEST(mixing, fixed_value) { - DirichletMixing mix; - bayesmix::DPPrior prior; - double m = 2.0; - prior.mutable_fixed_value()->set_totalmass(m); - double m_state = prior.fixed_value().totalmass(); - ASSERT_DOUBLE_EQ(m, m_state); - mix.get_mutable_prior()->CopyFrom(prior); - mix.initialize(); - double m_mix = mix.get_state().totalmass; - ASSERT_DOUBLE_EQ(m, m_mix); - - std::vector> hiers(100); - unsigned int n_data = 1000; - mix.update_state(hiers, std::vector(n_data)); - double m_mix_after = mix.get_state().totalmass; - ASSERT_DOUBLE_EQ(m, m_mix_after); -} - -TEST(mixing, gamma_prior) { - DirichletMixing mix; - bayesmix::DPPrior prior; - double alpha = 1.0; - double beta = 2.0; - double m_prior = alpha / beta; - prior.mutable_gamma_prior()->mutable_totalmass_prior()->set_shape(alpha); - prior.mutable_gamma_prior()->mutable_totalmass_prior()->set_rate(beta); - mix.get_mutable_prior()->CopyFrom(prior); - mix.initialize(); - double m_mix = mix.get_state().totalmass; - ASSERT_DOUBLE_EQ(m_prior, m_mix); - - std::vector> hiers(100); - unsigned int n_data = 1000; - mix.update_state(hiers, std::vector(n_data)); - double m_mix_after = mix.get_state().totalmass; - - std::cout << " after = " << m_mix_after << std::endl; - ASSERT_TRUE(m_mix_after > m_mix); -} - -TEST(hierarchies, fixed_values) { - bayesmix::NNIGPrior prior; - bayesmix::AlgorithmState::HierarchyHypers prior_out; - prior.mutable_fixed_values()->set_mean(5.0); - prior.mutable_fixed_values()->set_var_scaling(0.1); - prior.mutable_fixed_values()->set_shape(2.0); - prior.mutable_fixed_values()->set_scale(2.0); - - auto hier = std::make_shared(); - hier->get_mutable_prior()->CopyFrom(prior); - hier->initialize(); - - std::vector> unique_values; - std::vector states; - - // Check equality before update - 
unique_values.push_back(hier); - for (size_t i = 1; i < 4; i++) { - unique_values.push_back(hier->clone()); - unique_values[i]->write_hypers_to_proto(&prior_out); - ASSERT_EQ(prior.fixed_values().DebugString(), - prior_out.nnig_state().DebugString()); - } - - // Check equality after update - unique_values[0]->update_hypers(states); - unique_values[0]->write_hypers_to_proto(&prior_out); - for (size_t i = 1; i < 4; i++) { - unique_values[i]->write_hypers_to_proto(&prior_out); - ASSERT_EQ(prior.fixed_values().DebugString(), - prior_out.nnig_state().DebugString()); - } -} - -TEST(hierarchies, normal_mean_prior) { - bayesmix::NNWPrior prior; - bayesmix::AlgorithmState::HierarchyHypers prior_out; - Eigen::Vector2d mu00; - mu00 << 0.0, 0.0; - auto ident = Eigen::Matrix2d::Identity(); - - prior.mutable_normal_mean_prior()->set_var_scaling(0.1); - bayesmix::to_proto( - mu00, - prior.mutable_normal_mean_prior()->mutable_mean_prior()->mutable_mean()); - bayesmix::to_proto( - ident, - prior.mutable_normal_mean_prior()->mutable_mean_prior()->mutable_var()); - prior.mutable_normal_mean_prior()->set_deg_free(2.0); - bayesmix::to_proto(ident, - prior.mutable_normal_mean_prior()->mutable_scale()); - - std::vector states(4); - for (int i = 0; i < states.size(); i++) { - double mean = 9.0 + i; - Eigen::Vector2d vec; - vec << mean, mean; - bayesmix::to_proto(vec, - states[i].mutable_multi_ls_state()->mutable_mean()); - bayesmix::to_proto(ident, - states[i].mutable_multi_ls_state()->mutable_prec()); - } - - NNWHierarchy hier; - hier.get_mutable_prior()->CopyFrom(prior); - hier.initialize(); - - hier.update_hypers(states); - hier.write_hypers_to_proto(&prior_out); - Eigen::Vector2d mean_out = bayesmix::to_eigen(prior_out.nnw_state().mean()); - std::cout << " after = " << mean_out(0) << " " << mean_out(1) - << std::endl; - assert(mu00(0) < mean_out(0) && mu00(1) < mean_out(1)); -} diff --git a/test/proto_utils.cc b/test/proto_utils.cc index baf904824..1ca53d155 100644 --- 
a/test/proto_utils.cc +++ b/test/proto_utils.cc @@ -2,7 +2,7 @@ #include -#include +#include #include "matrix.pb.h" diff --git a/test/rng.cc b/test/rng.cc index 79c29572a..47de9259a 100644 --- a/test/rng.cc +++ b/test/rng.cc @@ -2,7 +2,7 @@ #include -#include +#include #include "src/hierarchies/nnig_hierarchy.h" #include "src/utils/distributions.h" diff --git a/test/semi_hdp.cc b/test/semi_hdp.cc index b193bea14..df55b6830 100644 --- a/test/semi_hdp.cc +++ b/test/semi_hdp.cc @@ -1,7 +1,7 @@ #include -#include #include +#include #include #include "semihdp.pb.h" @@ -302,6 +302,7 @@ TEST(semihdp, sample_unique_values2) { } TEST(semihdp, sample_allocations1) { + bayesmix::Rng::Instance().seed(220122); std::vector data(2); data[0] = bayesmix::vstack({Eigen::MatrixXd::Zero(50, 1).array() + 5, diff --git a/test/write_proto.cc b/test/write_proto.cc index 60b60ba9f..da866e40a 100644 --- a/test/write_proto.cc +++ b/test/write_proto.cc @@ -3,7 +3,6 @@ #include "algorithm_state.pb.h" #include "ls_state.pb.h" #include "src/hierarchies/nnig_hierarchy.h" -#include "src/hierarchies/nnw_hierarchy.h" #include "src/utils/proto_utils.h" TEST(set_state, uni_ls) { @@ -46,26 +45,3 @@ TEST(write_proto, uni_ls) { ASSERT_EQ(mean, out_mean); ASSERT_EQ(var, out_var); } - -TEST(set_state, multi_ls) { - Eigen::VectorXd mean = Eigen::VectorXd::Ones(5); - Eigen::MatrixXd prec = Eigen::MatrixXd::Identity(5, 5); - prec(1, 1) = 10.0; - - bayesmix::MultiLSState curr; - bayesmix::to_proto(mean, curr.mutable_mean()); - bayesmix::to_proto(prec, curr.mutable_prec()); - - ASSERT_EQ(curr.mean().data(0), 1.0); - ASSERT_EQ(curr.prec().data(0), 1.0); - ASSERT_EQ(curr.prec().data(6), 10.0); - - bayesmix::AlgorithmState::ClusterState clusval_in; - clusval_in.mutable_multi_ls_state()->CopyFrom(curr); - NNWHierarchy cluster; - cluster.set_state_from_proto(clusval_in); - - ASSERT_EQ(curr.mean().data(0), cluster.get_state().mean(0)); - ASSERT_EQ(curr.prec().data(0), cluster.get_state().prec(0, 0)); - 
ASSERT_EQ(curr.prec().data(6), cluster.get_state().prec(1, 1)); -}