diff --git a/.gitignore b/.gitignore
index cad2ac38d..a4fdf64a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,7 +16,17 @@ sftp-config.json
*.local.*
# MacOS storage files
.DS_Store
+.dockerignore
+.ipynb_checkpoints/
+docs/_build/
+resources/benchmarks/datasets
+resources/2d
#CLion cash
.idea/
# Build debug folder
cmake-build-debug/
+
+# .old folders
+src/hierarchies/updaters/.old/
+test/.old/
+examples/gamma_hierarchy/.old/
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 28e08f633..e79f07f43 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -90,6 +90,7 @@ set(INCLUDE_PATHS
${CMAKE_CURRENT_LIST_DIR}/lib/math
${CMAKE_CURRENT_LIST_DIR}/lib/math/lib/boost_1.72.0
${CMAKE_CURRENT_LIST_DIR}/lib/math/lib/eigen_3.3.9
+ ${CMAKE_CURRENT_LIST_DIR}/lib/math/lib/sundials_5.5.0/include
# TBB already included
${CMAKE_CURRENT_BINARY_DIR}
${protobuf_SOURCE_DIR}/src
diff --git a/README.md b/README.md
index d924dacd5..7cb6972a4 100644
--- a/README.md
+++ b/README.md
@@ -8,17 +8,15 @@ Current state of the software:
- `bayesmix` performs inference for mixture models of the kind
+
+
+
-where P is either the Dirichlet process or the Pitman--Yor process
-
-- We currently support univariate and multivariate location-scale mixture of Gaussian densities
-
-- Inference is carried out using algorithms such as Algorithm 2 in [Neal (2000)](http://www.stat.columbia.edu/npbayes/papers/neal_sampling.pdf)
-
-- Serialization of the MCMC chains is possible using Google's [Protocol Buffers](https://developers.google.com/protocol-buffers) aka `protobuf`
+For descriptions of the models supported in our library, discussion of software design, and examples, please refer to the following paper: https://arxiv.org/abs/2205.08144
# Installation
@@ -29,7 +27,7 @@ where P is either the Dirichlet process or the Pitman--Yor process
To install and use `bayesmix`, please `cd` to the folder to which you wish to install it, and clone this repository with the following command-line instruction:
```shell
-git clone --recurse-submodule git@github.com:bayesmix-dev/bayesmix.git
+git clone --recurse-submodules git@github.com:bayesmix-dev/bayesmix.git
```
Then, by using `cd bayesmix`, you will enter the newly downloaded folder.
@@ -39,8 +37,8 @@ To build the executable for the main file `run_mcmc.cc`, please use the followin
```shell
mkdir build
cd build
-cmake .. -DDISABLE_DOCS=on -DDISABLE_BENCHMARKS=on -DDISABLE_TESTS=on
-make run
+cmake .. -DDISABLE_DOCS=ON -DDISABLE_BENCHMARKS=ON -DDISABLE_TESTS=ON
+make run_mcmc
cd ..
```
@@ -97,3 +95,4 @@ Documentation is available at https://bayesmix.readthedocs.io.
# Contributions are welcome!
Please check out [CONTRIBUTING.md](CONTRIBUTING.md) for details on how to collaborate with us.
+You can also head to our [issues page](https://github.com/bayesmix-dev/bayesmix/issues) to check for useful enhancements needed.
diff --git a/benchmarks/lpd_grid.cc b/benchmarks/lpd_grid.cc
index d297f5943..507b20d0b 100644
--- a/benchmarks/lpd_grid.cc
+++ b/benchmarks/lpd_grid.cc
@@ -1,7 +1,7 @@
#include
-#include
#include
+#include
#include "src/utils/distributions.h"
#include "utils.h"
@@ -37,7 +37,6 @@ Eigen::VectorXd lpdf_naive(const Eigen::MatrixXd &x,
return out;
}
-
Eigen::VectorXd lpdf_fully_optimized(const Eigen::MatrixXd &x,
const Eigen::VectorXd &mean,
const Eigen::MatrixXd &prec_chol,
@@ -84,7 +83,6 @@ static void BM_gauss_lpdf_naive(benchmark::State &state) {
}
}
-
static void BM_gauss_lpdf_fully_optimized(benchmark::State &state) {
int dim = state.range(0);
Eigen::VectorXd mean = Eigen::VectorXd::Zero(dim);
@@ -99,7 +97,6 @@ static void BM_gauss_lpdf_fully_optimized(benchmark::State &state) {
}
}
-
BENCHMARK(BM_gauss_lpdf_cov)->RangeMultiplier(2)->Range(2, 2 << 4);
BENCHMARK(BM_gauss_lpdf_cov)->RangeMultiplier(2)->Range(2, 2 << 4);
BENCHMARK(BM_gauss_lpdf_naive)->RangeMultiplier(2)->Range(2, 2 << 4);
diff --git a/benchmarks/nnw_marg_lpdf.cc b/benchmarks/nnw_marg_lpdf.cc
index 696a8a198..a99314692 100644
--- a/benchmarks/nnw_marg_lpdf.cc
+++ b/benchmarks/nnw_marg_lpdf.cc
@@ -1,7 +1,7 @@
#include
-#include
#include
+#include
#include "utils.h"
diff --git a/docs/Doxyfile.in b/docs/Doxyfile.in
index 30cda29e9..6b163fa99 100644
--- a/docs/Doxyfile.in
+++ b/docs/Doxyfile.in
@@ -1819,7 +1819,7 @@ PAPER_TYPE = a4
# If left blank no extra packages will be included.
# This tag requires that the tag GENERATE_LATEX is set to YES.
-EXTRA_PACKAGES =
+EXTRA_PACKAGES = bm
# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the
# generated LaTeX document. The header should contain everything until the first
@@ -1893,7 +1893,7 @@ USE_PDFLATEX = YES
# The default value is: NO.
# This tag requires that the tag GENERATE_LATEX is set to YES.
-LATEX_BATCHMODE = NO
+LATEX_BATCHMODE = YES
# If the LATEX_HIDE_INDICES tag is set to YES then doxygen will not include the
# index chapters (such as File Index, Compound Index, etc.) in the output.
diff --git a/docs/algorithms.rst b/docs/algorithms.rst
index c61efc257..e15aff394 100644
--- a/docs/algorithms.rst
+++ b/docs/algorithms.rst
@@ -20,6 +20,9 @@ Algorithms
.. doxygenclass:: Neal8Algorithm
:project: bayesmix
:members:
+.. doxygenclass:: SplitAndMergeAlgorithm
+ :project: bayesmix
+ :members:
.. doxygenclass:: ConditionalAlgorithm
:project: bayesmix
:members:
diff --git a/docs/conf.py b/docs/conf.py
index 3501007dd..9acbf232a 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -57,8 +57,6 @@ def configureDoxyfile(input_dir, output_dir):
html_theme = 'haiku'
-html_static_path = ['_static']
-
highlight_language = 'cpp'
imgmath_latex = 'latex'
diff --git a/docs/hierarchies.rst b/docs/hierarchies.rst
index 21c6fe1d7..c6dedcab9 100644
--- a/docs/hierarchies.rst
+++ b/docs/hierarchies.rst
@@ -6,6 +6,20 @@ Hierarchies
In our algorithms, we store a vector of hierarchies, each of which represent a parameter :math:`\theta_h`.
The hierarchy implements all the methods needed to update :math:`\theta_h`: sampling from the prior distribution :math:`P_0`, the full-conditional distribution (given the data {:math:`y_i` such that :math:`c_i = h`} ) and so on.
+In BayesMix, each choice of :math:`G_0` is implemented in a different ``PriorModel`` object and each choice of :math:`k(\cdot \mid \cdot)` in a ``Likelihood`` object, so that it is straightforward to create a new ``Hierarchy`` using one of the already implemented priors or likelihoods.
+The sampling from the full conditional of :math:`\theta_h` is performed in an ``Updater`` class.
+``State`` classes are used to store the parameters :math:`\theta_h` of every mixture component.
+Their main purpose is to handle serialization and de-serialization of the state.
+
+.. toctree::
+ :maxdepth: 1
+ :caption: API: hierarchies submodules
+
+ likelihoods
+ prior_models
+ updaters
+ states
+
-------------------------
Main operations performed
@@ -13,17 +27,17 @@ Main operations performed
A hierarchy must be able to perform the following operations:
-1. Sample from the prior distribution: generate :math:`\theta_h \sim P_0` [``sample_prior``]
-2. Sample from the 'full conditional' distribution: generate theta_h from the distribution :math:`p(\theta_h \mid \cdots ) \propto P_0(\theta_h) \prod_{i: c_i = h} k(y_i | \theta_h)` [``sample_full_conditional``]
-3. Update the hyperparameters involved in :math:`P_0` [``update_hypers``]
-4. Evaluate the likelihood in one point, i.e. :math:`k(x | \theta_h)` for theta_h the current value of the parameters [``like_lpdf``]
-5. When :math:`k` and :math:`P_0` are conjugate, we must also be able to compute the marginal/prior predictive distribution in one point, i.e. :math:`m(x) = \int k(x | \theta) P_0(d\theta)`, and the conditional predictive distribution :math:`m(x | \textbf{y} ) = \int k(x | \theta) P_0(d\theta | \{y_i: c_i = h\})` [``prior_pred_lpdf``, ``conditional_pred_lpdf``]
+a. Sample from the prior distribution: generate :math:`\theta_h \sim P_0` [``sample_prior``]
+b. Sample from the 'full conditional' distribution: generate theta_h from the distribution :math:`p(\theta_h \mid \cdots ) \propto P_0(\theta_h) \prod_{i: c_i = h} k(y_i | \theta_h)` [``sample_full_conditional``]
+c. Update the hyperparameters involved in :math:`P_0` [``update_hypers``]
+d. Evaluate the likelihood in one point, i.e. :math:`k(x | \theta_h)` for theta_h the current value of the parameters [``like_lpdf``]
+e. When :math:`k` and :math:`P_0` are conjugate, we must also be able to compute the marginal/prior predictive distribution in one point, i.e. :math:`m(x) = \int k(x | \theta) P_0(d\theta)`, and the conditional predictive distribution :math:`m(x | \textbf{y} ) = \int k(x | \theta) P_0(d\theta | \{y_i: c_i = h\})` [``prior_pred_lpdf``, ``conditional_pred_lpdf``]
Moreover, the following utilities are needed:
-6. write the current state :math:`\theta_h` into a appropriately defined Protobuf message [``write_state_to_proto``]
-7. restore theta_h from a given Protobuf message [``set_state_from_proto``]
-8. write the values of the hyperparameters in :math:`P_0` to a Protobuf message [``write_hypers_to_proto``]
+f. write the current state :math:`\theta_h` into an appropriately defined Protobuf message [``write_state_to_proto``]
+g. restore theta_h from a given Protobuf message [``set_state_from_proto``]
+h. write the values of the hyperparameters in :math:`P_0` to a Protobuf message [``write_hypers_to_proto``]
In each hierarchy, we also keep track of which data points are allocated to the hierarchy.
@@ -44,19 +58,18 @@ A template class ``BaseHierarchy`` inherits from ``AbstractHierarchy`` and imple
Instead, child classes must implement:
-1. ``like_lpdf``: evaluates :math:`k(x | \theta_h)`
-2. ``marg_lpdf``: evaluates m(x) given some parameters :math:`\theta_h` (could be both the hyperparameters in :math:`P_0` or the paramters given by the full conditionals)
-3. ``draw``: samples from :math:`P_0` given the parameters
-4. ``clear_summary_statistics``: clears all the summary statistics
-5. ``update_hypers``: performs the update of parameters in :math:`P_0` given all the :math:`\theta_h` (passed as a vector of protobuf Messages)
-6. ``initialize_state``: initializes the current :math:`\theta_h` given the hyperparameters in :math:`P_0`
-7. ``initialize_hypers``: initializes the hyperparameters in :math:`P_0` given their hyperprior
-8. ``update_summary_statistics``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy
-9. ``get_posterior_parameters``: returns the paramters of the full conditional distribution **possible only when** :math:`P_0` **and** :math:`k` **are conjugate**
-10. ``set_state_from_proto``
-11. ``write_state_to_proto``
-12. ``write_hypers_to_proto``
-
+a. ``like_lpdf``: evaluates :math:`k(x | \theta_h)`
+b. ``marg_lpdf``: evaluates m(x) given some parameters :math:`\theta_h` (could be both the hyperparameters in :math:`P_0` or the parameters given by the full conditionals)
+c. ``draw``: samples from :math:`P_0` given the parameters
+d. ``clear_summary_statistics``: clears all the summary statistics
+e. ``update_hypers``: performs the update of parameters in :math:`P_0` given all the :math:`\theta_h` (passed as a vector of protobuf Messages)
+f. ``initialize_state``: initializes the current :math:`\theta_h` given the hyperparameters in :math:`P_0`
+g. ``initialize_hypers``: initializes the hyperparameters in :math:`P_0` given their hyperprior
+h. ``update_summary_statistics``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy
+i. ``get_posterior_parameters``: returns the parameters of the full conditional distribution **possible only when** :math:`P_0` **and** :math:`k` **are conjugate**
+j. ``set_state_from_proto``
+k. ``write_state_to_proto``
+l. ``write_hypers_to_proto``
Note that not all of these members are declared virtual in ``AbstractHierarchy`` or ``BaseHierarchy``: this is because virtual members are only the ones that must be called from outside the ``Hierarchy``, the other ones are handled via CRTP. Not having them virtual saves a lot of lookups in the vtables.
The ``BaseHierarchy`` class takes 4 template parameters:
@@ -66,12 +79,12 @@ The ``BaseHierarchy`` class takes 4 template parameters:
3. ``Hyperparams`` is usually a struct representing the parameters in :math:`P_0`
4. ``Prior`` must be a protobuf object encoding the prior parameters.
+.. Finally, a ``ConjugateHierarchy`` takes care of the implementation of some methods that are specific to conjugate models.
-Finally, a ``ConjugateHierarchy`` takes care of the implementation of some methods that are specific to conjugate models.
--------
-Classes
--------
+----------------
+Abstract Classes
+----------------
.. doxygenclass:: AbstractHierarchy
:project: bayesmix
@@ -79,9 +92,11 @@ Classes
.. doxygenclass:: BaseHierarchy
:project: bayesmix
:members:
-.. doxygenclass:: ConjugateHierarchy
- :project: bayesmix
- :members:
+
+---------------------------------
+Classes for Conjugate Hierarchies
+---------------------------------
+
.. doxygenclass:: NNIGHierarchy
:project: bayesmix
:members:
@@ -91,3 +106,17 @@ Classes
.. doxygenclass:: LinRegUniHierarchy
:project: bayesmix
:members:
+
+-------------------------------------
+Classes for Non-Conjugate Hierarchies
+-------------------------------------
+
+.. doxygenclass:: NNxIGHierarchy
+ :project: bayesmix
+ :members:
+.. doxygenclass:: LapNIGHierarchy
+ :project: bayesmix
+ :members:
+.. doxygenclass:: FAHierarchy
+ :project: bayesmix
+ :members:
diff --git a/docs/index.rst b/docs/index.rst
index 59d19db44..271387414 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -32,7 +32,8 @@ There are currently three submodules to the ``bayesmix`` library, represented by
Further, we employ Protocol buffers for several purposes, including serialization. The list of all protos with their docs is available in the ``protos`` link below.
.. toctree::
- :maxdepth: 1
+ :maxdepth: 2
+ :titlesonly:
:caption: API: library submodules
algorithms
@@ -43,11 +44,15 @@ Further, we employ Protocol buffers for several purposes, including serializatio
utils
-
Tutorials
=========
-:doc:`tutorial`
+.. toctree::
+ :maxdepth: 1
+
+ tutorial
+
+.. :doc:`tutorial`
Python interface
diff --git a/docs/likelihoods.rst b/docs/likelihoods.rst
new file mode 100644
index 000000000..72f2d73cd
--- /dev/null
+++ b/docs/likelihoods.rst
@@ -0,0 +1,85 @@
+bayesmix/hierarchies/likelihoods
+
+Likelihoods
+===========
+
+The ``Likelihood`` sub-module represents the likelihood we have assumed for the data in a given cluster. Each ``Likelihood`` class represents the sampling model
+
+.. math::
+ y_1, \ldots, y_k \mid \bm{\tau} \stackrel{\small\mathrm{iid}}{\sim} f(\cdot \mid \bm{\tau})
+
+for a specific choice of the probability density function :math:`f`.
+
+-------------------------
+Main operations performed
+-------------------------
+
+A likelihood object must be able to perform the following operations:
+
+a. First of all, we require the ``lpdf()`` and ``lpdf_grid()`` methods, which simply evaluate the loglikelihood in a given point or in a grid of points (also in case of a *dependent* likelihood, i.e., with covariates associated to each observation) [``lpdf()`` and ``lpdf_grid``]
+b. In case you want to rely on a Metropolis-like updater, the likelihood needs to evaluate the likelihood of the whole cluster starting from the vector of unconstrained parameters [``cluster_lpdf_from_unconstrained()``]. Observe that the ``AbstractLikelihood`` class provides two such methods, one returning a ``double`` and one returning a ``stan::math::var``. The latter is used to automatically compute the gradient of the likelihood via Stan's automatic differentiation, if needed. In practice, users do not need to implement both methods separately and can implement only one templated method
+c. manage the insertion and deletion of a datum in the cluster [``add_datum`` and ``remove_datum``]
+d. update the summary statistics associated to the likelihood [``update_summary_statistics``]. Summary statistics (when available) are used to evaluate the likelihood function on the whole cluster, as well as to perform the posterior updates of :math:`\bm{\tau}`. This usually gives a substantial speedup
+
+--------------
+Code structure
+--------------
+
+In principle, the ``Likelihood`` classes are responsible only for evaluating the log-likelihood function given a specific choice of parameters :math:`\bm{\tau}`.
+Therefore, a simple inheritance structure would seem appropriate. However, the nature of the parameters :math:`\bm{\tau}` can be very different across different models (think for instance of the difference between the univariate normal and the multivariate normal parameters). As such, we employ CRTP to manage the polymorphic nature of ``Likelihood`` classes.
+
+The class ``AbstractLikelihood`` defines the API, i.e. all the methods that need to be called from outside of a ``Likelihood`` class.
+A template class ``BaseLikelihood`` inherits from ``AbstractLikelihood`` and implements some of the necessary virtual methods, which need not be implemented by the child classes.
+
+Instead, child classes **must** implement:
+
+a. ``compute_lpdf``: evaluates :math:`k(x \mid \theta_h)`
+b. ``update_sum_stats``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy
+c. ``clear_summary_statistics``: clears all the summary statistics
+d. ``is_dependent``: defines if the given likelihood depends on covariates
+e. ``is_multivariate``: defines if the given likelihood is for multivariate data
+
+In case the likelihood needs to be used in a Metropolis-like updater, child classes **should** also implement:
+
+f. ``cluster_lpdf_from_unconstrained``: evaluates :math:`\prod_{i: c_i = h} k(x_i \mid \tilde{\theta}_h)`, where :math:`\tilde{\theta}_h` is the vector of unconstrained parameters.
+
+----------------
+Abstract Classes
+----------------
+
+.. doxygenclass:: AbstractLikelihood
+ :project: bayesmix
+ :members:
+.. doxygenclass:: BaseLikelihood
+ :project: bayesmix
+ :members:
+
+----------------------------------
+Classes for Univariate Likelihoods
+----------------------------------
+
+.. doxygenclass:: UniNormLikelihood
+ :project: bayesmix
+ :members:
+ :protected-members:
+.. doxygenclass:: UniLinRegLikelihood
+ :project: bayesmix
+ :members:
+ :protected-members:
+.. doxygenclass:: LaplaceLikelihood
+ :project: bayesmix
+ :members:
+ :protected-members:
+
+------------------------------------
+Classes for Multivariate Likelihoods
+------------------------------------
+
+.. doxygenclass:: MultiNormLikelihood
+ :project: bayesmix
+ :members:
+ :protected-members:
+.. doxygenclass:: FALikelihood
+ :project: bayesmix
+ :members:
+ :protected-members:
diff --git a/docs/mixings.rst b/docs/mixings.rst
index 67d2b1074..64f3b593f 100644
--- a/docs/mixings.rst
+++ b/docs/mixings.rst
@@ -34,6 +34,9 @@ Classes
.. doxygenclass:: PitYorMixing
:project: bayesmix
:members:
+.. doxygenclass:: MixtureFiniteMixing
+ :project: bayesmix
+ :members:
.. doxygenclass:: TruncatedSBMixing
:project: bayesmix
:members:
diff --git a/docs/prior_models.rst b/docs/prior_models.rst
new file mode 100644
index 000000000..002ab1f85
--- /dev/null
+++ b/docs/prior_models.rst
@@ -0,0 +1,87 @@
+bayesmix/hierarchies/prior_models
+
+Prior Models
+============
+
+A ``PriorModel`` represents the prior for the parameters in the likelihood, i.e.
+
+.. math::
+ \bm{\tau} \sim G_{0}
+
+with :math:`G_{0}` being a suitable prior on the parameters space. We also allow for more flexible priors adding further level of randomness (i.e. the hyperprior) on the parameter characterizing :math:`G_{0}`
+
+-------------------------
+Main operations performed
+-------------------------
+
+A prior model object must be able to perform the following operations:
+
+a. First of all, ``lpdf()`` and ``lpdf_from_unconstrained()`` methods evaluate the log-prior density function at the current state :math:`\bm \tau` or its unconstrained representation.
+In particular, ``lpdf_from_unconstrained()`` is needed by Metropolis-like updaters.
+
+b. The ``sample()`` method generates a draw from the prior distribution. If ``hier_hypers`` is ``nullptr``, the prior hyperparameter values are used.
+To allow sampling from the full conditional distribution in case of semi-conjugate hierarchies, we introduce the ``hier_hypers`` parameter, which is a pointer to a ``Protobuf`` message storing the hierarchy hyperparameters to use for the sampling.
+
+c. The ``update_hypers()`` method updates the prior hyperparameters, given the vector of all cluster states.
+
+
+--------------
+Code structure
+--------------
+
+As for the ``Likelihood`` classes we employ the Curiously Recurring Template Pattern to manage the polymorphic nature of ``PriorModel`` classes.
+
+The class ``AbstractPriorModel`` defines the API, i.e. all the methods that need to be called from outside of a ``PriorModel`` class.
+A template class ``BasePriorModel`` inherits from ``AbstractPriorModel`` and implements some of the necessary virtual methods, which need not be implemented by the child classes.
+
+Instead, child classes **must** implement:
+
+a. ``lpdf``: evaluates :math:`G_0(\theta_h)`
+b. ``sample``: samples from :math:`G_0` given the hyperparameters (passed as a pointer). If ``hier_hypers`` is ``nullptr``, the prior hyperparameter values are used.
+c. ``set_hypers_from_proto``: sets the hyperparameters from a ``Protobuf`` message
+d. ``get_hypers_proto``: returns the hyperparameters as a ``Protobuf`` message
+e. ``initialize_hypers``: provides a default initialization of hyperparameters
+
+In case you want to use a Metropolis-like updater, child classes **should** also implement:
+
+f. ``lpdf_from_unconstrained``: evaluates :math:`G_0(\tilde{\theta}_h)`, where :math:`\tilde{\theta}_h` is the vector of unconstrained parameters.
+
+----------------
+Abstract Classes
+----------------
+
+.. doxygenclass:: AbstractPriorModel
+ :project: bayesmix
+ :members:
+.. doxygenclass:: BasePriorModel
+ :project: bayesmix
+ :members:
+
+--------------------
+Non-abstract Classes
+--------------------
+
+.. doxygenclass:: NIGPriorModel
+ :project: bayesmix
+ :members:
+ :protected-members:
+
+.. doxygenclass:: NxIGPriorModel
+ :project: bayesmix
+ :members:
+ :protected-members:
+
+.. doxygenclass:: NWPriorModel
+ :project: bayesmix
+ :members:
+ :protected-members:
+
+.. doxygenclass:: MNIGPriorModel
+ :project: bayesmix
+ :members:
+ :protected-members:
+
+.. doxygenclass:: FAPriorModel
+ :project: bayesmix
+ :members:
+ :protected-members:
diff --git a/docs/protos.html b/docs/protos.html
index abc4c641b..26b9929a3 100644
--- a/docs/protos.html
+++ b/docs/protos.html
@@ -3,8 +3,12 @@
Protocol Documentation
-
-
+
+
-
-
+
-
Number of clusters to initialize the algorithm. It may be overridden by conditional mixings for which the number of components is fixed (e.g. TruncatedSBMixing). In this case, this value is ignored.
+ Number of clusters to initialize the algorithm. It may be
+ overridden by conditional mixings for which the number of
+ components is fixed (e.g. TruncatedSBMixing). In this case, this
+ value is ignored.
+
Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint32 instead.
-
int32
-
int
-
int
-
int32
-
int
-
integer
-
Bignum or Fixnum (as required)
-
-
-
-
int64
-
Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint64 instead.
-
int64
-
long
-
int/long
-
int64
-
long
-
integer/string
-
Bignum
-
-
-
-
uint32
-
Uses variable-length encoding.
-
uint32
-
int
-
int/long
-
uint32
-
uint
-
integer
-
Bignum or Fixnum (as required)
-
-
-
-
uint64
-
Uses variable-length encoding.
-
uint64
-
long
-
int/long
-
uint64
-
ulong
-
integer/string
-
Bignum or Fixnum (as required)
-
-
-
-
sint32
-
Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int32s.
-
int32
-
int
-
int
-
int32
-
int
-
integer
-
Bignum or Fixnum (as required)
-
-
-
-
sint64
-
Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int64s.
-
int64
-
long
-
int/long
-
int64
-
long
-
integer/string
-
Bignum
-
-
-
-
fixed32
-
Always four bytes. More efficient than uint32 if values are often greater than 2^28.
-
uint32
-
int
-
int
-
uint32
-
uint
-
integer
-
Bignum or Fixnum (as required)
-
-
-
-
fixed64
-
Always eight bytes. More efficient than uint64 if values are often greater than 2^56.
-
uint64
-
long
-
int/long
-
uint64
-
ulong
-
integer/string
-
Bignum
-
-
-
-
sfixed32
-
Always four bytes.
-
int32
-
int
-
int
-
int32
-
int
-
integer
-
Bignum or Fixnum (as required)
-
-
-
-
sfixed64
-
Always eight bytes.
-
int64
-
long
-
int/long
-
int64
-
long
-
integer/string
-
Bignum
-
-
-
-
bool
-
-
bool
-
boolean
-
boolean
-
bool
-
bool
-
boolean
-
TrueClass/FalseClass
-
-
-
-
string
-
A string must always contain UTF-8 encoded or 7-bit ASCII text.
-
string
-
String
-
str/unicode
-
string
-
string
-
string
-
String (UTF-8)
-
-
-
-
bytes
-
May contain any arbitrary sequence of bytes.
-
string
-
ByteString
-
str
-
[]byte
-
ByteString
-
string
-
String (ASCII-8BIT)
-
-
+
+
double
+
+
double
+
double
+
float
+
float64
+
double
+
float
+
Float
+
+
+
+
float
+
+
float
+
float
+
float
+
float32
+
float
+
float
+
Float
+
+
+
+
int32
+
+ Uses variable-length encoding. Inefficient for encoding negative
+ numbers – if your field is likely to have negative values, use
+ sint32 instead.
+
+
int32
+
int
+
int
+
int32
+
int
+
integer
+
Bignum or Fixnum (as required)
+
+
+
+
int64
+
+ Uses variable-length encoding. Inefficient for encoding negative
+ numbers – if your field is likely to have negative values, use
+ sint64 instead.
+
+
int64
+
long
+
int/long
+
int64
+
long
+
integer/string
+
Bignum
+
+
+
+
uint32
+
Uses variable-length encoding.
+
uint32
+
int
+
int/long
+
uint32
+
uint
+
integer
+
Bignum or Fixnum (as required)
+
+
+
+
uint64
+
Uses variable-length encoding.
+
uint64
+
long
+
int/long
+
uint64
+
ulong
+
integer/string
+
Bignum or Fixnum (as required)
+
+
+
+
sint32
+
+ Uses variable-length encoding. Signed int value. These more
+ efficiently encode negative numbers than regular int32s.
+
+
int32
+
int
+
int
+
int32
+
int
+
integer
+
Bignum or Fixnum (as required)
+
+
+
+
sint64
+
+ Uses variable-length encoding. Signed int value. These more
+ efficiently encode negative numbers than regular int64s.
+
+
int64
+
long
+
int/long
+
int64
+
long
+
integer/string
+
Bignum
+
+
+
+
fixed32
+
+ Always four bytes. More efficient than uint32 if values are often
+ greater than 2^28.
+
+
uint32
+
int
+
int
+
uint32
+
uint
+
integer
+
Bignum or Fixnum (as required)
+
+
+
+
fixed64
+
+ Always eight bytes. More efficient than uint64 if values are often
+ greater than 2^56.
+
+
uint64
+
long
+
int/long
+
uint64
+
ulong
+
integer/string
+
Bignum
+
+
+
+
sfixed32
+
Always four bytes.
+
int32
+
int
+
int
+
int32
+
int
+
integer
+
Bignum or Fixnum (as required)
+
+
+
+
sfixed64
+
Always eight bytes.
+
int64
+
long
+
int/long
+
int64
+
long
+
integer/string
+
Bignum
+
+
+
+
bool
+
+
bool
+
boolean
+
boolean
+
bool
+
bool
+
boolean
+
TrueClass/FalseClass
+
+
+
+
string
+
+ A string must always contain UTF-8 encoded or 7-bit ASCII text.
+
+
string
+
String
+
str/unicode
+
string
+
string
+
string
+
String (UTF-8)
+
+
+
+
bytes
+
May contain any arbitrary sequence of bytes.
+
string
+
ByteString
+
str
+
[]byte
+
ByteString
+
string
+
String (ASCII-8BIT)
+
-
diff --git a/docs/states.rst b/docs/states.rst
new file mode 100644
index 000000000..68de1feb1
--- /dev/null
+++ b/docs/states.rst
@@ -0,0 +1,40 @@
+bayesmix/hierarchies/likelihoods/states
+
+States
+======
+
+``States`` are classes used to store parameters :math:`\theta_h` of every mixture component.
+Their main purpose is to handle serialization and de-serialization of the state.
+Moreover, they allow to go from the constrained to the unconstrained representation of the parameters (and vice versa) and compute the associated determinant of the Jacobian appearing in the change of density formula.
+
+
+--------------
+Code Structure
+--------------
+
+All classes must inherit from the ``BaseState`` class
+
+.. doxygenclass:: State::BaseState
+ :project: bayesmix
+ :members:
+
+Depending on the chosen ``Updater``, the unconstrained representation might not be needed, and the methods ``get_unconstrained()``, ``set_from_unconstrained()`` and ``log_det_jac()`` might never be called.
+Therefore, we do not force users to implement them.
+Instead, the ``set_from_proto()`` and ``get_as_proto()`` are fundamental as they allow the interaction with Google's Protocol Buffers library.
+
+-------------
+State Classes
+-------------
+
+.. doxygenclass:: State::UniLS
+ :project: bayesmix
+ :members:
+
+.. doxygenclass:: State::MultiLS
+ :project: bayesmix
+ :members:
+
+.. doxygenclass:: State::FA
+ :project: bayesmix
+ :members:
+ :protected-members:
diff --git a/docs/updaters.rst b/docs/updaters.rst
new file mode 100644
index 000000000..ac3fb19c0
--- /dev/null
+++ b/docs/updaters.rst
@@ -0,0 +1,76 @@
+bayesmix/hierarchies/updaters
+
+Updaters
+========
+
+An ``Updater`` implements the machinery to provide a sampling from the full conditional distribution of a given hierarchy.
+
+The only operation performed is ``draw`` that samples from the full conditional, either exactly or via Markov chain Monte Carlo.
+
+.. doxygenclass:: AbstractUpdater
+ :project: bayesmix
+ :members:
+
+--------------
+Code Structure
+--------------
+
+We distinguish between semi-conjugate updaters and the metropolis-like updaters.
+
+
+Semi Conjugate Updaters
+-----------------------
+
+A semi-conjugate updater can be used when the full conditional distribution has the same form of the prior. Therefore, to sample from the full conditional, it is sufficient to call the ``draw`` method of the prior, but with an updated set of hyperparameters.
+
+The class ``SemiConjugateUpdater`` defines the API
+
+.. doxygenclass:: SemiConjugateUpdater
+ :project: bayesmix
+ :members:
+
+Classes inheriting from this one should only implement the ``compute_posterior_hypers(...)`` member function.
+
+
+Metropolis-like Updaters
+------------------------
+
+A Metropolis updater uses the Metropolis-Hastings algorithm (or its variations) to sample from the full conditional density.
+
+.. doxygenclass:: MetropolisUpdater
+ :project: bayesmix
+ :members:
+
+
+Classes inheriting from this one should only implement the ``sample_proposal(...)`` method, which samples from the proposal distribution, and the ``proposal_lpdf`` one, which evaluates the proposal density log-probability density function.
+
+---------------
+Updater Classes
+---------------
+
+.. doxygenclass:: RandomWalkUpdater
+ :project: bayesmix
+ :members:
+.. doxygenclass:: MalaUpdater
+ :project: bayesmix
+ :members:
+.. doxygenclass:: NNIGUpdater
+ :project: bayesmix
+ :members:
+ :protected-members:
+.. doxygenclass:: NNxIGUpdater
+ :project: bayesmix
+ :members:
+ :protected-members:
+.. doxygenclass:: NNWUpdater
+ :project: bayesmix
+ :members:
+ :protected-members:
+.. doxygenclass:: MNIGUpdater
+ :project: bayesmix
+ :members:
+ :protected-members:
+.. doxygenclass:: FAUpdater
+ :project: bayesmix
+ :members:
+ :protected-members:
diff --git a/docs/utils.rst b/docs/utils.rst
index 89c7e4533..45420f464 100644
--- a/docs/utils.rst
+++ b/docs/utils.rst
@@ -18,17 +18,20 @@ Distribution-related utilities
:project: bayesmix
----------------------------------------------
-``Eigen`` input-output and matrix manipulation
+``Eigen`` matrix manipulation utilities
----------------------------------------------
.. doxygenfile:: eigen_utils.h
:project: bayesmix
- :members:
+
+--------------------------------
+``Eigen`` input-output utilities
+--------------------------------
.. doxygenfile:: io_utils.h
:project: bayesmix
--------------------------
-``protobuf`` input-output
--------------------------
+-----------------------------------
+``protobuf`` input-output utilities
+-----------------------------------
.. doxygenfile:: proto_utils.h
:project: bayesmix
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 733662a26..09934a9f0 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,9 +1,12 @@
cmake_minimum_required(VERSION 3.13.0)
project(examples_bayesmix)
-add_executable(run_gamma $
+add_executable(run_gamma $
gamma_hierarchy/run_gamma_gamma.cc
- gamma_hierarchy/gamma_gamma_hier.h
+ gamma_hierarchy/gammagamma_hierarchy.h
+ gamma_hierarchy/gamma_likelihood.h
+ gamma_hierarchy/gamma_prior_model.h
+ gamma_hierarchy/gammagamma_updater.h
)
target_include_directories(run_gamma PUBLIC ${INCLUDE_PATHS})
diff --git a/examples/gamma_hierarchy/gamma_gamma_hier.h b/examples/gamma_hierarchy/gamma_gamma_hier.h
deleted file mode 100644
index ecd43325b..000000000
--- a/examples/gamma_hierarchy/gamma_gamma_hier.h
+++ /dev/null
@@ -1,140 +0,0 @@
-#ifndef BAYESMIX_HIERARCHIES_GAMMAGAMMA_HIERARCHY_H_
-#define BAYESMIX_HIERARCHIES_GAMMAGAMMA_HIERARCHY_H_
-
-#include
-#include
-#include
-
-#include
-#include
-#include
-
-#include "hierarchy_prior.pb.h"
-
-namespace GammaGamma {
-//! Custom container for State values
-struct State {
- double rate;
-};
-
-//! Custom container for Hyperparameters values
-struct Hyperparams {
- double shape, rate_alpha, rate_beta;
-};
-}; // namespace GammaGamma
-
-class GammaGammaHierarchy
- : public ConjugateHierarchy {
- public:
- GammaGammaHierarchy(const double shape, const double rate_alpha,
- const double rate_beta)
- : shape(shape), rate_alpha(rate_alpha), rate_beta(rate_beta) {
- create_empty_prior();
- }
- ~GammaGammaHierarchy() = default;
-
- double like_lpdf(const Eigen::RowVectorXd &datum) const override {
- return stan::math::gamma_lpdf(datum(0), hypers->shape, state.rate);
- }
-
- double marg_lpdf(
- const GammaGamma::Hyperparams ¶ms, const Eigen::RowVectorXd &datum,
- const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const {
- throw std::runtime_error("marg_lpdf() not implemented");
- return 0;
- }
-
- GammaGamma::State draw(const GammaGamma::Hyperparams ¶ms) {
- return GammaGamma::State{stan::math::gamma_rng(
- params.rate_alpha, params.rate_beta, bayesmix::Rng::Instance().get())};
- }
-
- void update_summary_statistics(const Eigen::RowVectorXd &datum,
- const bool add) {
- if (add) {
- data_sum += datum(0);
- ndata += 1;
- } else {
- data_sum -= datum(0);
- ndata -= 1;
- }
- }
-
- //! Computes and return posterior hypers given data currently in this cluster
- GammaGamma::Hyperparams compute_posterior_hypers() {
- GammaGamma::Hyperparams out;
- out.shape = hypers->shape;
- out.rate_alpha = hypers->rate_alpha + hypers->shape * ndata;
- out.rate_beta = hypers->rate_beta + data_sum;
- return out;
- }
-
- void initialize_state() override {
- state.rate = hypers->rate_alpha / hypers->rate_beta;
- }
-
- void initialize_hypers() {
- hypers->shape = shape;
- hypers->rate_alpha = rate_alpha;
- hypers->rate_beta = rate_beta;
- }
-
- //! Removes every data point from this cluster
- void clear_summary_statistics() {
- data_sum = 0;
- ndata = 0;
- }
-
- bool is_multivariate() const override { return false; }
-
- void set_state_from_proto(const google::protobuf::Message &state_) override {
- auto &statecast = google::protobuf::internal::down_cast<
- const bayesmix::AlgorithmState::ClusterState &>(state_);
- state.rate = statecast.general_state().data()[0];
- set_card(statecast.cardinality());
- }
-
- std::shared_ptr get_state_proto()
- const override {
- bayesmix::Vector state_;
- state_.mutable_data()->Add(state.rate);
-
- auto out = std::make_unique();
- out->mutable_general_state()->CopyFrom(state_);
- return out;
- }
-
- void update_hypers(const std::vector
- &states) override {
- return;
- }
-
- void write_hypers_to_proto(
- google::protobuf::Message *const out) const override {
- return;
- }
-
- void set_hypers_from_proto(
- const google::protobuf::Message &state_) override {
- return;
- }
-
- std::shared_ptr get_hypers_proto()
- const override {
- return nullptr;
- }
-
- bayesmix::HierarchyId get_id() const override {
- return bayesmix::HierarchyId::UNKNOWN_HIERARCHY;
- }
-
- protected:
- double data_sum = 0;
- int ndata = 0;
-
- double shape, rate_alpha, rate_beta;
-};
-
-#endif // BAYESMIX_HIERARCHIES_GAMMAGAMMA_HIERARCHY_H_
diff --git a/examples/gamma_hierarchy/gamma_likelihood.h b/examples/gamma_hierarchy/gamma_likelihood.h
new file mode 100644
index 000000000..25b3103c5
--- /dev/null
+++ b/examples/gamma_hierarchy/gamma_likelihood.h
@@ -0,0 +1,82 @@
+#ifndef BAYESMIX_HIERARCHIES_GAMMA_LIKELIHOOD_H_
+#define BAYESMIX_HIERARCHIES_GAMMA_LIKELIHOOD_H_
+
+#include
+
+#include
+#include
+#include
+
+#include "algorithm_state.pb.h"
+#include "src/hierarchies/likelihoods/base_likelihood.h"
+#include "src/hierarchies/likelihoods/states/base_state.h"
+
+namespace State {
+class Gamma : public BaseState {
+ public:
+ double shape, rate;
+ using ProtoState = bayesmix::AlgorithmState::ClusterState;
+
+ ProtoState get_as_proto() const override {
+ ProtoState out;
+ out.mutable_general_state()->set_size(2);
+ out.mutable_general_state()->mutable_data()->Add(shape);
+ out.mutable_general_state()->mutable_data()->Add(rate);
+ return out;
+ }
+
+ void set_from_proto(const ProtoState &state_, bool update_card) override {
+ if (update_card) {
+ card = state_.cardinality();
+ }
+ shape = state_.general_state().data()[0];
+ rate = state_.general_state().data()[1];
+ }
+};
+} // namespace State
+
+class GammaLikelihood : public BaseLikelihood {
+ public:
+ GammaLikelihood() = default;
+ ~GammaLikelihood() = default;
+ bool is_multivariate() const override { return false; };
+ bool is_dependent() const override { return false; };
+ void clear_summary_statistics() override;
+
+ // Getters and Setters
+ int get_ndata() const { return ndata; };
+ double get_shape() const { return state.shape; };
+ double get_data_sum() const { return data_sum; };
+
+ protected:
+ double compute_lpdf(const Eigen::RowVectorXd &datum) const override;
+ void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override;
+
+ //! Sum of data in the cluster
+ double data_sum = 0;
+ //! number of data in the cluster
+ int ndata = 0;
+};
+
+/* DEFINITIONS */
+void GammaLikelihood::clear_summary_statistics() {
+ data_sum = 0;
+ ndata = 0;
+}
+
+double GammaLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const {
+ return stan::math::gamma_lpdf(datum(0), state.shape, state.rate);
+}
+
+void GammaLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum,
+ bool add) {
+ if (add) {
+ data_sum += datum(0);
+ ndata += 1;
+ } else {
+ data_sum -= datum(0);
+ ndata -= 1;
+ }
+}
+
+#endif // BAYESMIX_HIERARCHIES_GAMMA_LIKELIHOOD_H_
diff --git a/examples/gamma_hierarchy/gamma_prior_model.h b/examples/gamma_hierarchy/gamma_prior_model.h
new file mode 100644
index 000000000..a49bafadc
--- /dev/null
+++ b/examples/gamma_hierarchy/gamma_prior_model.h
@@ -0,0 +1,107 @@
+#ifndef BAYESMIX_HIERARCHIES_GAMMA_PRIOR_MODEL_H_
+#define BAYESMIX_HIERARCHIES_GAMMA_PRIOR_MODEL_H_
+
+#include
+#include
+#include
+#include
+
+#include "algorithm_state.pb.h"
+#include "gamma_likelihood.h"
+#include "hierarchy_prior.pb.h"
+#include "src/hierarchies/priors/base_prior_model.h"
+#include "src/utils/rng.h"
+
+namespace Hyperparams {
+struct Gamma {
+ double rate_alpha, rate_beta;
+};
+} // namespace Hyperparams
+
+class GammaPriorModel
+ : public BasePriorModel {
+ public:
+ using AbstractPriorModel::ProtoHypers;
+ using AbstractPriorModel::ProtoHypersPtr;
+
+ GammaPriorModel(double shape_ = -1, double rate_alpha_ = -1,
+ double rate_beta_ = -1);
+ ~GammaPriorModel() = default;
+
+ double lpdf(const google::protobuf::Message &state_) override;
+
+ State::Gamma sample(ProtoHypersPtr hier_hypers = nullptr) override;
+
+ void update_hypers(const std::vector
+ &states) override {
+ return;
+ };
+
+ void set_hypers_from_proto(
+ const google::protobuf::Message &hypers_) override;
+
+ ProtoHypersPtr get_hypers_proto() const override;
+ double get_shape() const { return shape; };
+
+ protected:
+ double shape, rate_alpha, rate_beta;
+ void initialize_hypers() override;
+};
+
+/* DEFINITIONS */
+GammaPriorModel::GammaPriorModel(double shape_, double rate_alpha_,
+ double rate_beta_)
+ : shape(shape_), rate_alpha(rate_alpha_), rate_beta(rate_beta_) {
+ create_empty_prior();
+};
+
+double GammaPriorModel::lpdf(const google::protobuf::Message &state_) {
+ double rate = downcast_state(state_).general_state().data()[1];
+ return stan::math::gamma_lpdf(rate, hypers->rate_alpha, hypers->rate_beta);
+}
+
+State::Gamma GammaPriorModel::sample(ProtoHypersPtr hier_hypers) {
+ auto &rng = bayesmix::Rng::Instance().get();
+ State::Gamma out;
+
+ auto params = (hier_hypers) ? hier_hypers->general_state()
+ : get_hypers_proto()->general_state();
+ double rate_alpha = params.data()[0];
+ double rate_beta = params.data()[1];
+ out.shape = shape;
+ out.rate = stan::math::gamma_rng(rate_alpha, rate_beta, rng);
+ return out;
+}
+
+void GammaPriorModel::set_hypers_from_proto(
+ const google::protobuf::Message &hypers_) {
+ auto &hyperscast = downcast_hypers(hypers_).general_state();
+ hypers->rate_alpha = hyperscast.data()[0];
+ hypers->rate_beta = hyperscast.data()[1];
+};
+
+GammaPriorModel::ProtoHypersPtr GammaPriorModel::get_hypers_proto() const {
+ ProtoHypersPtr out = std::make_shared();
+ out->mutable_general_state()->mutable_data()->Add(hypers->rate_alpha);
+ out->mutable_general_state()->mutable_data()->Add(hypers->rate_beta);
+ return out;
+};
+
+void GammaPriorModel::initialize_hypers() {
+ hypers->rate_alpha = rate_alpha;
+ hypers->rate_beta = rate_beta;
+
+ // Checks
+ if (shape <= 0) {
+ throw std::runtime_error("shape must be positive");
+ }
+ if (rate_alpha <= 0) {
+ throw std::runtime_error("rate_alpha must be positive");
+ }
+ if (rate_beta <= 0) {
+ throw std::runtime_error("rate_beta must be positive");
+ }
+}
+
+#endif // BAYESMIX_HIERARCHIES_GAMMA_PRIOR_MODEL_H_
diff --git a/examples/gamma_hierarchy/gammagamma_hierarchy.h b/examples/gamma_hierarchy/gammagamma_hierarchy.h
new file mode 100644
index 000000000..6fe135533
--- /dev/null
+++ b/examples/gamma_hierarchy/gammagamma_hierarchy.h
@@ -0,0 +1,47 @@
+#ifndef BAYESMIX_HIERARCHIES_GAMMA_GAMMA_HIERARCHY_H_
+#define BAYESMIX_HIERARCHIES_GAMMA_GAMMA_HIERARCHY_H_
+
+#include "gamma_likelihood.h"
+#include "gamma_prior_model.h"
+#include "gammagamma_updater.h"
+#include "hierarchy_id.pb.h"
+#include "src/hierarchies/base_hierarchy.h"
+
+class GammaGammaHierarchy
+ : public BaseHierarchy {
+ public:
+ GammaGammaHierarchy(double shape_, double rate_alpha_, double rate_beta_) {
+ auto prior =
+ std::make_shared(shape_, rate_alpha_, rate_beta_);
+ set_prior(prior);
+ };
+ ~GammaGammaHierarchy() = default;
+
+ bayesmix::HierarchyId get_id() const override {
+ return bayesmix::HierarchyId::UNKNOWN_HIERARCHY;
+ }
+
+ void set_default_updater() {
+ updater = std::make_shared();
+ }
+
+ void initialize_state() override {
+ // Get hypers
+ auto hypers = prior->get_hypers();
+ // Initialize likelihood state
+ State::Gamma state;
+ state.shape = prior->get_shape();
+ state.rate = hypers.rate_alpha / hypers.rate_beta;
+ like->set_state(state);
+ };
+
+ double marg_lpdf(ProtoHypersPtr hier_params,
+ const Eigen::RowVectorXd &datum) const override {
+ throw(
+ std::runtime_error("marg_lpdf() not implemented for this hierarchy"));
+ return 0;
+ }
+};
+
+#endif // BAYESMIX_HIERARCHIES_GAMMA_GAMMA_HIERARCHY_H_
diff --git a/examples/gamma_hierarchy/gammagamma_updater.h b/examples/gamma_hierarchy/gammagamma_updater.h
new file mode 100644
index 000000000..dd05b6ffd
--- /dev/null
+++ b/examples/gamma_hierarchy/gammagamma_updater.h
@@ -0,0 +1,49 @@
+#ifndef BAYESMIX_HIERARCHIES_GAMMA_GAMMA_UPDATER_H_
+#define BAYESMIX_HIERARCHIES_GAMMA_GAMMA_UPDATER_H_
+
+#include "gamma_likelihood.h"
+#include "gamma_prior_model.h"
+#include "src/hierarchies/updaters/semi_conjugate_updater.h"
+
+class GammaGammaUpdater
+ : public SemiConjugateUpdater {
+ public:
+ GammaGammaUpdater() = default;
+ ~GammaGammaUpdater() = default;
+
+ bool is_conjugate() const override { return true; };
+
+ ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood& like,
+ AbstractPriorModel& prior) override;
+};
+
+/* DEFINITIONS */
+AbstractUpdater::ProtoHypersPtr GammaGammaUpdater::compute_posterior_hypers(
+ AbstractLikelihood& like, AbstractPriorModel& prior) {
+ // Likelihood and Prior downcast
+ auto& likecast = downcast_likelihood(like);
+ auto& priorcast = downcast_prior(prior);
+
+ // Getting required quantities from likelihood and prior
+ int card = likecast.get_card();
+ double data_sum = likecast.get_data_sum();
+ double ndata = likecast.get_ndata();
+ double shape = priorcast.get_shape();
+ auto hypers = priorcast.get_hypers();
+
+ // No update possible
+ if (card == 0) {
+ return priorcast.get_hypers_proto();
+ }
+ // Compute posterior hyperparameters
+ double rate_alpha_new = hypers.rate_alpha + shape * ndata;
+ double rate_beta_new = hypers.rate_beta + data_sum;
+
+ // Proto conversion
+ ProtoHypers out;
+ out.mutable_general_state()->mutable_data()->Add(rate_alpha_new);
+ out.mutable_general_state()->mutable_data()->Add(rate_beta_new);
+ return std::make_shared(out);
+}
+
+#endif // BAYESMIX_HIERARCHIES_GAMMA_GAMMA_UPDATER_H_
diff --git a/examples/gamma_hierarchy/run_gamma_gamma.cc b/examples/gamma_hierarchy/run_gamma_gamma.cc
index f6502b72d..d8e5cf0f5 100644
--- a/examples/gamma_hierarchy/run_gamma_gamma.cc
+++ b/examples/gamma_hierarchy/run_gamma_gamma.cc
@@ -1,6 +1,6 @@
#include
-#include "gamma_gamma_hier.h"
+#include "gammagamma_hierarchy.h"
#include "src/includes.h"
Eigen::MatrixXd simulate_data(const unsigned int ndata) {
diff --git a/executables/run_mcmc.cc b/executables/run_mcmc.cc
index eccc423c3..8a47aabb8 100644
--- a/executables/run_mcmc.cc
+++ b/executables/run_mcmc.cc
@@ -167,6 +167,7 @@ int main(int argc, char *argv[]) {
mixing->get_mutable_prior());
bayesmix::read_proto_from_file(args.get("--hier-args"),
hier->get_mutable_prior());
+ hier->initialize();
// Read data matrices
Eigen::MatrixXd data =
diff --git a/python/notebooks/gaussian_mix_NNxIG.ipynb b/python/notebooks/gaussian_mix_NNxIG.ipynb
new file mode 100644
index 000000000..96e3fa370
--- /dev/null
+++ b/python/notebooks/gaussian_mix_NNxIG.ipynb
@@ -0,0 +1,149 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6c73fa6a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "from bayesmixpy import run_mcmc"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "79825dfc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"BAYESMIX_EXE\"] = \"../../build/run_mcmc\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "64a83071",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Generate data\n",
+ "data = np.concatenate([\n",
+ " np.random.normal(loc=3, scale=1, size=100),\n",
+ " np.random.normal(loc=-3, scale=1, size=100),\n",
+ "])\n",
+ "\n",
+ "# Plot data\n",
+ "plt.hist(data)\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "13df394d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hierarchy settings\n",
+ "hier_params = \\\n",
+ "\"\"\"\n",
+ "fixed_values {\n",
+ " mean: 0.0\n",
+ " var: 10.0\n",
+ " shape: 2.0\n",
+ " scale: 2.0\n",
+ "}\n",
+ "\"\"\"\n",
+ "\n",
+ "# Mixing settings\n",
+ "mix_params = \\\n",
+ "\"\"\"\n",
+ "fixed_value {\n",
+ " totalmass: 1.0\n",
+ "}\n",
+ "\"\"\"\n",
+ "\n",
+ "# Algorithm settings\n",
+ "algo_params = \\\n",
+ "\"\"\"\n",
+ "algo_id: \"Neal8\"\n",
+ "rng_seed: 20201124\n",
+ "iterations: 2000\n",
+ "burnin: 1000\n",
+ "init_num_clusters: 3\n",
+ "neal8_n_aux: 3\n",
+ "\"\"\"\n",
+ "\n",
+ "# Evaluation grid\n",
+ "dens_grid = np.linspace(-6.5, 6.5, 1000)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "12505b6e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Fit model using bayesmixpy\n",
+ "eval_dens, n_clus, clus_chain, best_clus = run_mcmc(\"NNxIG\",\"DP\", data,\n",
+ " hier_params, mix_params, algo_params,\n",
+ " dens_grid, return_num_clusters=True,\n",
+ " return_clusters=True, return_best_clus=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8470fc0a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
+ "\n",
+ "# Posterior distribution of clusters\n",
+ "x, y = np.unique(n_clus, return_counts=True)\n",
+ "axes[0].bar(x, y / y.sum())\n",
+ "axes[0].set_xticks(x)\n",
+ "axes[0].set_title(\"Posterior distribution of the number of clusters\")\n",
+ "\n",
+ "# Plot mean posterior density\n",
+ "axes[1].plot(dens_grid, np.exp(np.mean(eval_dens, axis=0)))\n",
+ "axes[1].hist(data, alpha=0.3, density=True)\n",
+ "for c in np.unique(best_clus):\n",
+ " data_in_clus = data[best_clus == c]\n",
+ " axes[1].scatter(data_in_clus, np.zeros_like(data_in_clus) + 0.01)\n",
+ "axes[1].set_title(\"Posterior density estimate\")\n",
+ "\n",
+ "# Show results\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/python/notebooks/gaussian_mix_uni.ipynb b/python/notebooks/gaussian_mix_uni.ipynb
index 2c6f7f4f0..7a9c71ed1 100644
--- a/python/notebooks/gaussian_mix_uni.ipynb
+++ b/python/notebooks/gaussian_mix_uni.ipynb
@@ -210,24 +210,24 @@
"neal2_algo = \"\"\"\n",
"algo_id: \"Neal2\"\n",
"rng_seed: 20201124\n",
- "iterations: 10\n",
- "burnin: 5\n",
+ "iterations: 100\n",
+ "burnin: 50\n",
"init_num_clusters: 3\n",
"\"\"\"\n",
"\n",
"neal3_algo = \"\"\"\n",
"algo_id: \"Neal3\"\n",
"rng_seed: 20201124\n",
- "iterations: 10\n",
- "burnin: 5\n",
+ "iterations: 100\n",
+ "burnin: 50\n",
"init_num_clusters: 3\n",
"\"\"\"\n",
"\n",
"neal8_algo = \"\"\"\n",
"algo_id: \"Neal8\"\n",
"rng_seed: 20201124\n",
- "iterations: 10\n",
- "burnin: 5\n",
+ "iterations: 100\n",
+ "burnin: 50\n",
"init_num_clusters: 3\n",
"neal8_n_aux: 3\n",
"\"\"\"\n",
@@ -244,7 +244,7 @@
"\n",
"`return_clusters=False, return_num_clusters=False, return_best_clus=False`\n",
"\n",
- "Observe that the number of iterations is extremely small! In real problems, you might want to set the burnin at least to 1000 iterations and the total number of iterations to at leas 2000."
+ "Observe that the number of iterations is extremely small! In real problems, you might want to set the burnin at least to 1000 iterations and the total number of iterations to at least 2000."
]
},
{
@@ -449,13 +449,6 @@
"source": [
"np.var(data)"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
@@ -463,7 +456,7 @@
"hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
},
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -477,7 +470,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.10"
+ "version": "3.8.10"
}
},
"nbformat": 4,
diff --git a/src/algorithms/base_algorithm.cc b/src/algorithms/base_algorithm.cc
index 19597a406..bfe9ac76c 100644
--- a/src/algorithms/base_algorithm.cc
+++ b/src/algorithms/base_algorithm.cc
@@ -1,7 +1,7 @@
#include "base_algorithm.h"
-#include
#include
+#include
#include
#include "algorithm_params.pb.h"
diff --git a/src/algorithms/base_algorithm.h b/src/algorithms/base_algorithm.h
index 988260f5e..b42df0611 100644
--- a/src/algorithms/base_algorithm.h
+++ b/src/algorithms/base_algorithm.h
@@ -3,8 +3,8 @@
#include
-#include
#include
+#include
#include
#include "algorithm_id.pb.h"
diff --git a/src/algorithms/conditional_algorithm.cc b/src/algorithms/conditional_algorithm.cc
index 30090013a..aebc2cdf5 100644
--- a/src/algorithms/conditional_algorithm.cc
+++ b/src/algorithms/conditional_algorithm.cc
@@ -1,7 +1,7 @@
#include "conditional_algorithm.h"
-#include
#include
+#include
#include "algorithm_state.pb.h"
#include "base_algorithm.h"
diff --git a/src/algorithms/conditional_algorithm.h b/src/algorithms/conditional_algorithm.h
index f99b01b1e..f7bed0391 100644
--- a/src/algorithms/conditional_algorithm.h
+++ b/src/algorithms/conditional_algorithm.h
@@ -1,34 +1,40 @@
#ifndef BAYESMIX_ALGORITHMS_CONDITIONAL_ALGORITHM_H_
#define BAYESMIX_ALGORITHMS_CONDITIONAL_ALGORITHM_H_
-#include
#include
+#include
#include "base_algorithm.h"
#include "src/collectors/base_collector.h"
-//! Template class for a conditional sampler deriving from `BaseAlgorithm`.
-
-//! This template class implements a generic Gibbs sampling conditional
-//! algorithm as the child of the `BaseAlgorithm` class.
-//! A mixture model sampled from a conditional algorithm can be expressed as
-//! x_i | c_i, phi_1, ..., phi_k ~ f(x_i|phi_(c_i)) (data likelihood);
-//! phi_1, ... phi_k ~ G (unique values);
-//! c_1, ... c_n | w_1, ..., w_k ~ Cat(w_1, ... w_k) (cluster allocations);
-//! w_1, ..., w_k ~ p(w_1, ..., w_k) (mixture weights)
-//! where f(x | phi_j) is a density for each value of phi_j, the c_i take
-//! values in {1, ..., k} and w_1, ..., w_k are nonnegative weights whose sum
-//! is a.s. 1, i.e. p(w_1, ... w_k) is a probability distribution on the k-1
-//! dimensional unit simplex).
-//! In this library, each phi_j is represented as an `Hierarchy` object (which
-//! inherits from `AbstractHierarchy`), that also holds the information related
-//! to the base measure `G` is (see `AbstractHierarchy`).
-//! The weights (w_1, ..., w_k) are represented as a `Mixing` object, which
-//! inherits from `AbstractMixing`.
-
-//! The state of a conditional algorithm consists of the unique values, the
-//! cluster allocations and the mixture weights. The former two are stored in
-//! this class, while the weights are stored in the `Mixing` object.
+/**
+ * Template class for a conditional sampler deriving from `BaseAlgorithm`.
+ *
+ * This template class implements a generic Gibbs sampling conditional
+ * algorithm as the child of the `BaseAlgorithm` class.
+ * A mixture model sampled from a conditional algorithm can be expressed as
+ *
+ * \f{align*}{
+ *    x_i \mid c_i, \theta_1, \dots, \theta_k &\sim f(x_i \mid \theta_{c_i}) \\
+ *    \theta_1, \dots, \theta_k &\sim G_0 \\
+ *    c_1, \dots, c_n \mid w_1, \dots, w_k &\sim \text{Cat}(w_1, \dots, w_k) \\
+ *    w_1, \dots, w_k &\sim p(w_1, \dots, w_k)
+ * \f}
+ *
+ * where \f$ f(x \mid \theta_j) \f$ is a density for each value of \f$ \theta_j
+ * \f$, \f$ c_i \f$ take values in \f$ \{1, \dots, k\} \f$ and \f$ w_1, \dots,
+ * w_k \f$ are nonnegative weights whose sum is a.s. 1, i.e. \f$ p(w_1, \dots,
+ * w_k) \f$ is a probability distribution on the k-1 dimensional unit simplex).
+ * In this library, each \f$ \theta_j \f$ is represented as an `Hierarchy`
+ * object (which inherits from `AbstractHierarchy`), that also holds the
+ * information related to the base measure \f$ G \f$ (see
+ * `AbstractHierarchy`). The weights \f$ (w_1, \dots, w_k) \f$ are represented
+ * as a `Mixing` object, which inherits from `AbstractMixing`.
+ *
+ * The state of a conditional algorithm consists of the unique values, the
+ * cluster allocations and the mixture weights. The former two are stored in
+ * this class, while the weights are stored in the `Mixing` object.
+ */
class ConditionalAlgorithm : public BaseAlgorithm {
public:
diff --git a/src/algorithms/marginal_algorithm.cc b/src/algorithms/marginal_algorithm.cc
index 279402c1a..ddf4340de 100644
--- a/src/algorithms/marginal_algorithm.cc
+++ b/src/algorithms/marginal_algorithm.cc
@@ -1,8 +1,8 @@
#include "marginal_algorithm.h"
-#include
#include
#include
+#include
#include "algorithm_state.pb.h"
#include "base_algorithm.h"
diff --git a/src/algorithms/marginal_algorithm.h b/src/algorithms/marginal_algorithm.h
index 44fd4164c..95fcddd92 100644
--- a/src/algorithms/marginal_algorithm.h
+++ b/src/algorithms/marginal_algorithm.h
@@ -1,36 +1,44 @@
#ifndef BAYESMIX_ALGORITHMS_MARGINAL_ALGORITHM_H_
#define BAYESMIX_ALGORITHMS_MARGINAL_ALGORITHM_H_
-#include
#include
+#include
#include "base_algorithm.h"
#include "src/collectors/base_collector.h"
#include "src/hierarchies/abstract_hierarchy.h"
-//! Template class for a marginal sampler deriving from `BaseAlgorithm`.
-
-//! This template class implements a generic Gibbs sampling marginal algorithm
-//! as the child of the `BaseAlgorithm` class.
-//! A mixture model sampled from a Marginal Algorithm can be expressed as
-//! x_i | c_i, phi_1, ..., phi_k ~ f(x_i|phi_(c_i)) (data likelihood);
-//! phi_1, ... phi_k ~ G (unique values);
-//! c_1, ... c_n ~ EPPF(c_1, ... c_n) (cluster allocations);
-//! where f(x | phi_j) is a density for each value of phi_j and the c_i take
-//! values in {1, ..., k}.
-//! Depending on the actual implementation, the algorithm might require
-//! the kernel/likelihood f(x | phi) and G(phi) to be conjugagte or not.
-//! In the former case, a `ConjugateHierarchy` must be specified.
-//! In this library, each phi_j is represented as an `Hierarchy` object (which
-//! inherits from `AbstractHierarchy`), that also holds the information related
-//! to the base measure `G` is (see `AbstractHierarchy`). The EPPF is instead
-//! represented as a `Mixing` object, which inherits from `AbstractMixing`.
-//!
-//! The state of a marginal algorithm only consists of allocations and unique
-//! values. In this class of algorithms, the local lpdf estimate for a single
-//! iteration is a weighted average of likelihood values corresponding to each
-//! component (i.e. cluster), with the weights being based on its cardinality,
-//! and of the marginal component, which depends on the specific algorithm.
+/**
+ * Template class for a marginal sampler deriving from `BaseAlgorithm`.
+ *
+ * This template class implements a generic Gibbs sampling marginal algorithm
+ * as the child of the `BaseAlgorithm` class.
+ * A mixture model sampled from a Marginal Algorithm can be expressed as
+ *
+ * \f{align*}{
+ *    x_i \mid c_i, \theta_1, \dots, \theta_k &\sim f(x_i \mid \theta_{c_i})
+ * \\
+ *    \theta_1, \dots, \theta_k &\sim G_0 \\
+ *    c_1, \dots, c_n &\sim \text{EPPF}(c_1, \dots, c_n)
+ * \f}
+ *
+ * where \f$ f(x \mid \theta_j) \f$ is a density for each value of \f$ \theta_j
+ * \f$ and \f$ c_i \f$ take values in \f$ \{1, \dots, k\} \f$. Depending on the
+ * actual implementation, the algorithm might require the kernel/likelihood \f$
+ * f(x \mid \theta) \f$ and \f$ G_0(\phi) \f$ to be conjugagte or not. In the
+ * former case, a conjugate hierarchy must be specified. In this library, each
+ * \f$ \theta_j \f$ is represented as an `Hierarchy` object (which inherits
+ * from `AbstractHierarchy`), that also holds the information related to the
+ * base measure \f$ G_0 \f$ (see `AbstractHierarchy`). The \f$ EPPF \f$ is
+ * instead represented as a `Mixing` object, which inherits from
+ * `AbstractMixing`.
+ *
+ * The state of a marginal algorithm only consists of allocations and unique
+ * values. In this class of algorithms, the local lpdf estimate for a single
+ * iteration is a weighted average of likelihood values corresponding to each
+ * component (i.e. cluster), with the weights being based on its cardinality,
+ * and of the marginal component, which depends on the specific algorithm.
+ */
class MarginalAlgorithm : public BaseAlgorithm {
public:
@@ -45,9 +53,12 @@ class MarginalAlgorithm : public BaseAlgorithm {
protected:
//! Computes marginal contribution of the given cluster to the lpdf estimate
- //! @param hier Pointer to the `Hierarchy` object representing the cluster
- //! @param grid Grid of row points on which the density is to be evaluated
- //! @return The marginal component of the estimate
+ //! @param hier Pointer to the `Hierarchy` object representing the
+ //! cluster
+ //! @param grid Grid of row points on which the density is to be
+ //! evaluated
+ //! @param covariate (Optional) covariate vectors associated to data
+ //! @return The marginal component of the estimate
virtual Eigen::VectorXd lpdf_marginal_component(
const std::shared_ptr hier,
const Eigen::MatrixXd &grid,
diff --git a/src/algorithms/neal2_algorithm.cc b/src/algorithms/neal2_algorithm.cc
index 2973cc8dc..d67550c07 100644
--- a/src/algorithms/neal2_algorithm.cc
+++ b/src/algorithms/neal2_algorithm.cc
@@ -1,8 +1,8 @@
#include "neal2_algorithm.h"
-#include
#include
#include
+#include
#include
#include "algorithm_id.pb.h"
diff --git a/src/algorithms/neal2_algorithm.h b/src/algorithms/neal2_algorithm.h
index d9f63dc6a..a51e8504d 100644
--- a/src/algorithms/neal2_algorithm.h
+++ b/src/algorithms/neal2_algorithm.h
@@ -1,8 +1,8 @@
#ifndef BAYESMIX_ALGORITHMS_NEAL2_ALGORITHM_H_
#define BAYESMIX_ALGORITHMS_NEAL2_ALGORITHM_H_
-#include
#include
+#include
#include "algorithm_id.pb.h"
#include "marginal_algorithm.h"
diff --git a/src/algorithms/neal3_algorithm.cc b/src/algorithms/neal3_algorithm.cc
index 7e209b81f..e7bf9712a 100644
--- a/src/algorithms/neal3_algorithm.cc
+++ b/src/algorithms/neal3_algorithm.cc
@@ -1,7 +1,7 @@
#include "neal3_algorithm.h"
-#include
#include
+#include
#include "hierarchy_id.pb.h"
#include "mixing_id.pb.h"
diff --git a/src/algorithms/neal8_algorithm.cc b/src/algorithms/neal8_algorithm.cc
index c38bc803c..6ee6d25fd 100644
--- a/src/algorithms/neal8_algorithm.cc
+++ b/src/algorithms/neal8_algorithm.cc
@@ -1,8 +1,8 @@
#include "neal8_algorithm.h"
-#include
#include
#include
+#include
#include "algorithm_id.pb.h"
#include "algorithm_state.pb.h"
diff --git a/src/algorithms/neal8_algorithm.h b/src/algorithms/neal8_algorithm.h
index 5ff026809..edbdb2e4a 100644
--- a/src/algorithms/neal8_algorithm.h
+++ b/src/algorithms/neal8_algorithm.h
@@ -1,8 +1,8 @@
#ifndef BAYESMIX_ALGORITHMS_NEAL8_ALGORITHM_H_
#define BAYESMIX_ALGORITHMS_NEAL8_ALGORITHM_H_
-#include
#include
+#include
#include
#include "algorithm_id.pb.h"
diff --git a/src/algorithms/semihdp_sampler.h b/src/algorithms/semihdp_sampler.h
index e684c6dd1..3558a3ae9 100644
--- a/src/algorithms/semihdp_sampler.h
+++ b/src/algorithms/semihdp_sampler.h
@@ -1,9 +1,9 @@
#ifndef SRC_ALGORITHMS_SEMIHDP_SAMPLER_H
#define SRC_ALGORITHMS_SEMIHDP_SAMPLER_H
-#include
#include
#include
+#include
#include
#include
diff --git a/src/algorithms/split_and_merge_algorithm.h b/src/algorithms/split_and_merge_algorithm.h
index 43d595c53..61d5c0d7d 100644
--- a/src/algorithms/split_and_merge_algorithm.h
+++ b/src/algorithms/split_and_merge_algorithm.h
@@ -1,8 +1,8 @@
#ifndef BAYESMIX_ALGORITHMS_SPLIT_AND_MERGE_ALGORITHM_H_
#define BAYESMIX_ALGORITHMS_SPLIT_AND_MERGE_ALGORITHM_H_
-#include
#include
+#include
#include "algorithm_id.pb.h"
#include "marginal_algorithm.h"
diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt
index 0207bd04f..42a45fd87 100644
--- a/src/hierarchies/CMakeLists.txt
+++ b/src/hierarchies/CMakeLists.txt
@@ -2,15 +2,14 @@ target_sources(bayesmix
PUBLIC
abstract_hierarchy.h
base_hierarchy.h
- conjugate_hierarchy.h
- lin_reg_uni_hierarchy.h
- lin_reg_uni_hierarchy.cc
nnig_hierarchy.h
- nnig_hierarchy.cc
+ nnxig_hierarchy.h
nnw_hierarchy.h
- nnw_hierarchy.cc
+ lin_reg_uni_hierarchy.h
fa_hierarchy.h
- fa_hierarchy.cc
lapnig_hierarchy.h
- lapnig_hierarchy.cc
)
+
+add_subdirectory(likelihoods)
+add_subdirectory(priors)
+add_subdirectory(updaters)
diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h
index b69716a0c..947f362cf 100644
--- a/src/hierarchies/abstract_hierarchy.h
+++ b/src/hierarchies/abstract_hierarchy.h
@@ -3,64 +3,88 @@
#include
-#include
#include
#include
#include
#include
+#include
#include "algorithm_state.pb.h"
#include "hierarchy_id.pb.h"
+#include "src/hierarchies/likelihoods/abstract_likelihood.h"
+#include "src/hierarchies/priors/abstract_prior_model.h"
+#include "src/hierarchies/updaters/abstract_updater.h"
#include "src/utils/rng.h"
-//! Abstract base class for a hierarchy object.
-//! This class is the basis for a curiously recurring template pattern (CRTP)
-//! for `Hierarchy` objects, and is solely composed of interface functions for
-//! derived classes to use. For more information about this pattern, as well
-//! the list of methods required for classes in this inheritance tree, please
-//! refer to the README.md file included in this folder.
-
-//! This abstract class represents a Bayesian hierarchical model:
-//! x_1, ..., x_n \sim f(x | \theta)
-//! theta \sim G
-//! A Hierarchy object can compute the following quantities:
-//! 1- the likelihood log-probability density function
-//! 2- the prior predictive probability: \int_\Theta f(x | theta) G(d\theta)
-//! (for conjugate models only)
-//! 3- the posterior predictive probability
-//! \int_\Theta f(x | theta) G(d\theta | x_1, ..., x_n)
-//! (for conjugate models only)
-//! Moreover, the Hierarchy knows how to sample from the full conditional of
-//! theta, possibly in an approximate way.
-//!
-//! In the context of our Gibbs samplers, an hierarchy represents the parameter
-//! value associated to a certain cluster, and also knows which observations
-//! are allocated to that cluster.
-//! Moreover, hyperparameters and (possibly) hyperpriors associated to them can
-//! be shared across multiple Hierarchies objects via a shared pointer.
-//! In conjunction with a single `Mixing` object, a collection of `Hierarchy`
-//! objects completely defines a mixture model, and these two parts can be
-//! chosen independently of each other.
-//! Communication with other classes, as well as storage of some relevant
-//! values, is performed via appropriately defined Protobuf messages (see for
-//! instance the proto/ls_state.proto and proto/hierarchy_prior.proto files)
-//! and their relative class methods.
+/**
+ * Abstract base class for a hierarchy object.
+ * This class is the basis for a curiously recurring template pattern (CRTP)
+ * for `Hierarchy` objects, and is solely composed of interface functions for
+ * derived classes to use. For more information about this pattern, as well
+ * the list of methods required for classes in this inheritance tree, please
+ * refer to the README.md file included in this folder.
+ *
+ * This abstract class represents a Bayesian hierarchical model:
+ *
+ * \f{align*}{
+ * x_1,\dots,x_n &\sim f(x \mid \theta) \\
+ * \theta &\sim G
+ * \f}
+ *
+ * A Hierarchy object can compute the following quantities:
+ *
+ * 1. the likelihood log-probability density function
+ * 2. the prior predictive probability: \f$ \int_\Theta f(x \mid \theta)
+ * G(d\theta) \f$ (for conjugate models only)
+ * 3. the posterior predictive probability
+ * \f$ \int_\Theta f(x \mid \theta) G(d\theta \mid x_1, ..., x_n) \f$
+ * (for conjugate models only)
+ *
+ * Moreover, the Hierarchy knows how to sample from the full conditional of
+ * \f$ \theta \f$, possibly in an approximate way.
+ *
+ * In the context of our Gibbs samplers, an hierarchy represents the parameter
+ * value associated to a certain cluster, and also knows which observations
+ * are allocated to that cluster.
+ *
+ * Moreover, hyperparameters and (possibly) hyperpriors associated to them can
+ * be shared across multiple Hierarchies objects via a shared pointer.
+ * In conjunction with a single `Mixing` object, a collection of `Hierarchy`
+ * objects completely defines a mixture model, and these two parts can be
+ * chosen independently of each other.
+ *
+ * Communication with other classes, as well as storage of some relevant
+ * values, is performed via appropriately defined Protobuf messages (see for
+ * instance the `proto/ls_state.proto` and `proto/hierarchy_prior.proto` files)
+ * and their relative class methods.
+ */
class AbstractHierarchy {
public:
+ //! Set the update algorithm for the current hierarchy
+ virtual void set_updater(std::shared_ptr<AbstractUpdater> updater_) = 0;
+
+ //! Returns (a pointer to) the likelihood for the current hierarchy
+ virtual std::shared_ptr<AbstractLikelihood> get_likelihood() = 0;
+
+ //! Returns (a pointer to) the prior model for the current hierarchy
+ virtual std::shared_ptr<AbstractPriorModel> get_prior() = 0;
+
+ //! Default destructor
virtual ~AbstractHierarchy() = default;
//! Returns an independent, data-less copy of this object
virtual std::shared_ptr clone() const = 0;
+ //! Returns an independent, data-less copy of this object
virtual std::shared_ptr deep_clone() const = 0;
// EVALUATION FUNCTIONS FOR SINGLE POINTS
//! Public wrapper for `like_lpdf()` methods
- double get_like_lpdf(
+ virtual double get_like_lpdf(
const Eigen::RowVectorXd &datum,
const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const {
- if (is_dependent()) {
+ if (is_dependent() and covariate.size() != 0) {
return like_lpdf(datum, covariate);
} else {
return like_lpdf(datum);
@@ -74,8 +98,13 @@ class AbstractHierarchy {
virtual double prior_pred_lpdf(
const Eigen::RowVectorXd &datum,
const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const {
- throw std::runtime_error(
- "Cannot call prior_pred_lpdf() from a non-conjugate hierarchy");
+ if (is_conjugate()) {
+ throw std::runtime_error(
+ "prior_pred_lpdf() not implemented for this hierarchy");
+ } else {
+ throw std::runtime_error(
+ "Cannot call prior_pred_lpdf() from a non-conjugate hierarchy");
+ }
}
//! Evaluates the log-conditional predictive distr. of data in a single point
@@ -85,8 +114,14 @@ class AbstractHierarchy {
virtual double conditional_pred_lpdf(
const Eigen::RowVectorXd &datum,
const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const {
- throw std::runtime_error(
- "Cannot call conditional_pred_lpdf() from a non-conjugate hierarchy");
+ if (is_conjugate()) {
+ throw std::runtime_error(
+ "conditional_pred_lpdf() not implemented for this hierarchy");
+ } else {
+ throw std::runtime_error(
+ "Cannot call conditional_pred_lpdf() from a non-conjugate "
+ "hierarchy");
+ }
}
// EVALUATION FUNCTIONS FOR GRIDS OF POINTS
@@ -105,8 +140,13 @@ class AbstractHierarchy {
virtual Eigen::VectorXd prior_pred_lpdf_grid(
const Eigen::MatrixXd &data,
const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const {
- throw std::runtime_error(
- "Cannot call prior_pred_lpdf_grid() from a non-conjugate hierarchy");
+ if (is_conjugate()) {
+ throw std::runtime_error(
+ "prior_pred_lpdf_grid() not implemented for this hierarchy");
+ } else {
+ throw std::runtime_error(
+ "Cannot call prior_pred_lpdf_grid() from a non-conjugate hierarchy");
+ }
}
//! Evaluates the log-prior predictive distr. of data in a grid of points
@@ -116,9 +156,14 @@ class AbstractHierarchy {
virtual Eigen::VectorXd conditional_pred_lpdf_grid(
const Eigen::MatrixXd &data,
const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const {
- throw std::runtime_error(
- "Cannot call conditional_pred_lpdf_grid() from a non-conjugate "
- "hierarchy");
+ if (is_conjugate()) {
+ throw std::runtime_error(
+ "conditional_pred_lpdf_grid() not implemented for this hierarchy");
+ } else {
+ throw std::runtime_error(
+ "Cannot call conditional_pred_lpdf_grid() from a non-conjugate "
+ "hierarchy");
+ }
}
// SAMPLING FUNCTIONS
@@ -208,12 +253,12 @@ class AbstractHierarchy {
virtual bool is_multivariate() const = 0;
//! Returns whether the hierarchy depends on covariate values or not
- virtual bool is_dependent() const { return false; }
+ virtual bool is_dependent() const = 0;
//! Returns whether the hierarchy represents a conjugate model or not
- virtual bool is_conjugate() const { return false; }
+ virtual bool is_conjugate() const = 0;
- //! Main function that initializes members to appropriate values
+ //! Sets the (pointer to) the dataset in the cluster
virtual void set_dataset(const Eigen::MatrixXd *const dataset) = 0;
protected:
@@ -227,7 +272,8 @@ class AbstractHierarchy {
throw std::runtime_error(
"Cannot call like_lpdf() from a non-dependent hierarchy");
} else {
- throw std::runtime_error("like_lpdf() not implemented");
+ throw std::runtime_error(
+ "like_lpdf() not implemented for this hierarchy");
}
}
@@ -239,7 +285,8 @@ class AbstractHierarchy {
throw std::runtime_error(
"Cannot call like_lpdf() from a dependent hierarchy");
} else {
- throw std::runtime_error("like_lpdf() not implemented");
+ throw std::runtime_error(
+ "like_lpdf() not implemented for this hierarchy");
}
}
@@ -255,7 +302,8 @@ class AbstractHierarchy {
"Cannot call update_summary_statistics() from a non-dependent "
"hierarchy");
} else {
- throw std::runtime_error("update_summary_statistics() not implemented");
+ throw std::runtime_error(
+ "update_summary_statistics() not implemented for this hierarchy");
}
}
@@ -269,7 +317,8 @@ class AbstractHierarchy {
"Cannot call update_summary_statistics() from a dependent "
"hierarchy");
} else {
- throw std::runtime_error("update_summary_statistics() not implemented");
+ throw std::runtime_error(
+ "update_summary_statistics() not implemented for this hierarchy");
}
}
};
diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h
index c4c4f8953..86ea08a0a 100644
--- a/src/hierarchies/base_hierarchy.h
+++ b/src/hierarchies/base_hierarchy.h
@@ -3,16 +3,17 @@
#include
-#include
#include
#include
#include
-#include
+#include
#include "abstract_hierarchy.h"
#include "algorithm_state.pb.h"
#include "hierarchy_id.pb.h"
+#include "src/utils/covariates_getter.h"
#include "src/utils/rng.h"
+#include "updaters/target_lpdf_unconstrained.h"
//! Base template class for a hierarchy object.
@@ -22,326 +23,358 @@
//! class for further information). It includes class members and some more
//! functions which could not be implemented in the non-templatized abstract
//! class.
-//! See, for instance, `ConjugateHierarchy` and `NNIGHierarchy` to better
-//! understand the CRTP patterns.
+//! See, for instance, `NNIGHierarchy` to better understand the CRTP patterns.
//! @tparam Derived Name of the implemented derived class
-//! @tparam State Class name of the container for state values
-//! @tparam Hyperparams Class name of the container for hyperprior parameters
-//! @tparam Prior Class name of the container for prior parameters
+//! @tparam Likelihood Class name of the likelihood model for the hierarchy
+//! @tparam PriorModel Class name of the prior model for the hierarchy
-template <class Derived, typename State, typename Hyperparams, typename Prior>
+template <class Derived, class Likelihood, class PriorModel>
class BaseHierarchy : public AbstractHierarchy {
+ protected:
+ //! Container for the likelihood of the hierarchy
+ std::shared_ptr<Likelihood> like = std::make_shared<Likelihood>();
+
+ //! Container for the prior model of the hierarchy
+ std::shared_ptr<PriorModel> prior = std::make_shared<PriorModel>();
+
+ //! Container for the update algorithm
+ std::shared_ptr<AbstractUpdater> updater;
+
public:
- BaseHierarchy() = default;
+ // Useful type aliases
+ using HyperParams = decltype(prior->get_hypers());
+ using ProtoHypersPtr = AbstractUpdater::ProtoHypersPtr;
+ using ProtoHypers = AbstractUpdater::ProtoHypers;
+
+ //! Constructor that allows the specification of Likelihood, PriorModel and
+ //! Updater for a given Hierarchy
+ BaseHierarchy(std::shared_ptr<AbstractLikelihood> like_ = nullptr,
+ std::shared_ptr<AbstractPriorModel> prior_ = nullptr,
+ std::shared_ptr<AbstractUpdater> updater_ = nullptr) {
+ if (like_) {
+ set_likelihood(like_);
+ }
+ if (prior_) {
+ set_prior(prior_);
+ }
+ if (updater_) {
+ set_updater(updater_);
+ } else {
+ static_cast<Derived *>(this)->set_default_updater();
+ }
+ }
+
+ //! Default destructor
~BaseHierarchy() = default;
+ //! Sets the likelihood for the current hierarchy
+ void set_likelihood(std::shared_ptr<AbstractLikelihood> like_) /*override*/ {
+ like = std::static_pointer_cast<Likelihood>(like_);
+ }
+
+ //! Sets the prior model for the current hierarchy
+ void set_prior(std::shared_ptr<AbstractPriorModel> prior_) /*override*/ {
+ prior = std::static_pointer_cast<PriorModel>(prior_);
+ }
+
+ //! Sets the update algorithm for the current hierarchy
+ void set_updater(std::shared_ptr<AbstractUpdater> updater_) override {
+ updater = updater_;
+ };
+
+ //! Returns (a pointer to) the likelihood for the current hierarchy
+ std::shared_ptr<AbstractLikelihood> get_likelihood() override {
+ return like;
+ }
+
+ //! Returns (a pointer to) the prior model for the current hierarchy.
+ std::shared_ptr<AbstractPriorModel> get_prior() override { return prior; }
+
//! Returns an independent, data-less copy of this object
std::shared_ptr clone() const override {
+ // Create copy of the hierarchy
+ auto out = std::make_shared<Derived>(static_cast<const Derived &>(*this));
- out->clear_data();
- out->clear_summary_statistics();
+ // Cloning each component class
+ out->set_likelihood(std::static_pointer_cast<Likelihood>(like->clone()));
+ out->set_prior(std::static_pointer_cast<PriorModel>(prior->clone()));
return out;
- }
+ };
- //! Returns an independent, data-less copy of this object
+ //! Returns an independent, data-less deep copy of this object
std::shared_ptr deep_clone() const override {
+ // Create copy of the hierarchy
+ auto out = std::make_shared<Derived>(static_cast<const Derived &>(*this));
+ // Simple clone for Likelihood is enough
+ out->set_likelihood(std::static_pointer_cast<Likelihood>(like->clone()));
+ // Deep clone required for PriorModel
+ out->set_prior(std::static_pointer_cast<PriorModel>(prior->deep_clone()));
+ return out;
+ }
- out->clear_data();
- out->clear_summary_statistics();
+ //! Public wrapper for `like_lpdf()` methods
+ double get_like_lpdf(const Eigen::RowVectorXd &datum,
+ const Eigen::RowVectorXd &covariate =
+ Eigen::RowVectorXd(0)) const override {
+ return like->lpdf(datum, covariate);
+ }
- out->create_empty_prior();
- std::shared_ptr new_prior(prior->New());
- new_prior->CopyFrom(*prior.get());
- out->get_mutable_prior()->CopyFrom(*new_prior.get());
+ //! Evaluates the log-likelihood of data in a grid of points
+ //! @param data Grid of points (by row) which are to be evaluated
+ //! @param covariates (Optional) covariate vectors associated to data
+ //! @return The evaluation of the lpdf
+ Eigen::VectorXd like_lpdf_grid(const Eigen::MatrixXd &data,
+ const Eigen::MatrixXd &covariates =
+ Eigen::MatrixXd(0, 0)) const override {
+ return like->lpdf_grid(data, covariates);
+ };
+
+ //! Public wrapper for `marg_lpdf()` methods
+ double get_marg_lpdf(
+ ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum,
+ const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const {
+ if (this->is_dependent() and covariate.size() != 0) {
+ return marg_lpdf(hier_params, datum, covariate);
+ } else {
+ return marg_lpdf(hier_params, datum);
+ }
+ }
- out->create_empty_hypers();
- auto curr_hypers_proto = get_hypers_proto();
- out->set_hypers_from_proto(*curr_hypers_proto.get());
- out->initialize();
- return out;
+ //! Evaluates the log-prior predictive distribution of data in a single point
+ //! @param datum Point which is to be evaluated
+ //! @param covariate (Optional) covariate vector associated to datum
+ //! @return The evaluation of the lpdf
+ double prior_pred_lpdf(const Eigen::RowVectorXd &datum,
+ const Eigen::RowVectorXd &covariate =
+ Eigen::RowVectorXd(0)) const override {
+ return get_marg_lpdf(prior->get_hypers_proto(), datum, covariate);
}
- //! Evaluates the log-likelihood of data in a grid of points
+ //! Evaluates the log-prior predictive distr. of data in a grid of points
//! @param data Grid of points (by row) which are to be evaluated
//! @param covariates (Optional) covariate vectors associated to data
//! @return The evaluation of the lpdf
- virtual Eigen::VectorXd like_lpdf_grid(
+ Eigen::VectorXd prior_pred_lpdf_grid(
const Eigen::MatrixXd &data,
- const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0,
- 0)) const override;
+ const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/)
+ const override {
+ Eigen::VectorXd lpdf(data.rows());
+ if (covariates.cols() == 0) {
+ // Pass null value as covariate
+ for (int i = 0; i < data.rows(); i++) {
+ lpdf(i) = static_cast<const Derived *>(this)->prior_pred_lpdf(
+ data.row(i), Eigen::RowVectorXd(0));
+ }
+ } else if (covariates.rows() == 1) {
+ // Use unique covariate
+ for (int i = 0; i < data.rows(); i++) {
+ lpdf(i) = static_cast<const Derived *>(this)->prior_pred_lpdf(
+ data.row(i), covariates.row(0));
+ }
+ } else {
+ // Use different covariates
+ for (int i = 0; i < data.rows(); i++) {
+ lpdf(i) = static_cast<const Derived *>(this)->prior_pred_lpdf(
+ data.row(i), covariates.row(i));
+ }
+ }
+ return lpdf;
+ }
- //! Generates new state values from the centering prior distribution
- void sample_prior() override {
- state = static_cast(this)->draw(*hypers);
+ //! Evaluates the log-conditional predictive distr. of data in a single point
+ //! @param datum Point which is to be evaluated
+ //! @param covariate (Optional) covariate vector associated to datum
+ //! @return The evaluation of the lpdf
+ double conditional_pred_lpdf(const Eigen::RowVectorXd &datum,
+ const Eigen::RowVectorXd &covariate =
+ Eigen::RowVectorXd(0)) const override {
+ return get_marg_lpdf(updater->compute_posterior_hypers(*like, *prior),
+ datum, covariate);
+ }
+
+ //! Evaluates the log-prior predictive distr. of data in a grid of points
+ //! @param data Grid of points (by row) which are to be evaluated
+ //! @param covariates (Optional) covariate vectors associated to data
+ //! @return The evaluation of the lpdf
+ Eigen::VectorXd conditional_pred_lpdf_grid(
+ const Eigen::MatrixXd &data,
+ const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/)
+ const override {
+ Eigen::VectorXd lpdf(data.rows());
+ covariates_getter cov_getter(covariates);
+ for (int i = 0; i < data.rows(); i++) {
+ lpdf(i) = static_cast<const Derived *>(this)->conditional_pred_lpdf(
+ data.row(i), cov_getter(i));
+ }
+ return lpdf;
}
+ //! Generates new state values from the centering prior distribution
+ void sample_prior() override { like->set_state(prior->sample(), false); };
+
+ //! Generates new state values from the centering posterior distribution
+ //! @param update_params Save posterior hypers after the computation?
+ void sample_full_cond(bool update_params = false) override {
+ updater->draw(*like, *prior, update_params);
+ };
+
//! Overloaded version of sample_full_cond(bool), mainly used for debugging
- virtual void sample_full_cond(
+ void sample_full_cond(
const Eigen::MatrixXd &data,
- const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override;
+ const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override {
+ like->clear_data();
+ like->clear_summary_statistics();
+
+ covariates_getter cov_getter(covariates);
+ for (int i = 0; i < data.rows(); i++) {
+ static_cast<Derived *>(this)->add_datum(i, data.row(i), false,
+ cov_getter(i));
+ }
+ static_cast<Derived *>(this)->sample_full_cond(true);
+ };
+
+ //! Updates hyperparameter values given a vector of cluster states
+ void update_hypers(const std::vector<bayesmix::AlgorithmState::ClusterState>
+ &states) override {
+ prior->update_hypers(states);
+ };
+
+ //! Returns the class of the current state
+ auto get_state() const -> decltype(like->get_state()) {
+ return like->get_state();
+ };
//! Returns the current cardinality of the cluster
- int get_card() const override { return card; }
+ int get_card() const override { return like->get_card(); };
//! Returns the logarithm of the current cardinality of the cluster
- double get_log_card() const override { return log_card; }
+ double get_log_card() const override { return like->get_log_card(); };
//! Returns the indexes of data points belonging to this cluster
- std::set<int> get_data_idx() const override { return cluster_data_idx; }
+ std::set<int> get_data_idx() const override { return like->get_data_idx(); };
+
+ //! Writes current state to a Protobuf message and return a shared_ptr
+ //! New hierarchies have to first modify the field 'oneof val' in the
+ //! AlgoritmState::ClusterState message by adding the appropriate type
+ std::shared_ptr<bayesmix::AlgorithmState::ClusterState> get_state_proto()
+ const override {
+ return std::make_shared<bayesmix::AlgorithmState::ClusterState>(
+ like->get_state().get_as_proto());
+ }
//! Returns a pointer to the Protobuf message of the prior of this cluster
google::protobuf::Message *get_mutable_prior() override {
- if (prior == nullptr) {
- create_empty_prior();
- }
- return prior.get();
- }
+ return prior->get_mutable_prior();
+ };
//! Writes current state to a Protobuf message by pointer
- void write_state_to_proto(
- google::protobuf::Message *const out) const override;
+ void write_state_to_proto(google::protobuf::Message *out) const override {
+ like->write_state_to_proto(out);
+ };
//! Writes current values of the hyperparameters to a Protobuf message by
//! pointer
- void write_hypers_to_proto(
- google::protobuf::Message *const out) const override;
+ void write_hypers_to_proto(google::protobuf::Message *out) const override {
+ prior->write_hypers_to_proto(out);
+ };
- //! Returns the struct of the current state
- State get_state() const { return state; }
+ //! Read and set state values from a given Protobuf message
+ void set_state_from_proto(const google::protobuf::Message &state_) override {
+ like->set_state_from_proto(state_);
+ };
- //! Returns the struct of the current prior hyperparameters
- Hyperparams get_hypers() const { return *hypers; }
-
- //! Returns the struct of the current posterior hyperparameters
- Hyperparams get_posterior_hypers() const { return posterior_hypers; }
+ //! Read and set hyperparameter values from a given Protobuf message
+ void set_hypers_from_proto(
+ const google::protobuf::Message &state_) override {
+ prior->set_hypers_from_proto(state_);
+ };
//! Adds a datum and its index to the hierarchy
void add_datum(
const int id, const Eigen::RowVectorXd &datum,
const bool update_params = false,
- const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override;
+ const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override {
+ like->add_datum(id, datum, covariate);
+ if (update_params) {
+ updater->save_posterior_hypers(
+ updater->compute_posterior_hypers(*like, *prior));
+ }
+ };
//! Removes a datum and its index from the hierarchy
void remove_datum(
const int id, const Eigen::RowVectorXd &datum,
const bool update_params = false,
- const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override;
+ const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override {
+ like->remove_datum(id, datum, covariate);
+ if (update_params) {
+ updater->save_posterior_hypers(
+ updater->compute_posterior_hypers(*like, *prior));
+ }
+ };
//! Main function that initializes members to appropriate values
void initialize() override {
- hypers = std::make_shared();
- check_prior_is_set();
- initialize_hypers();
+ prior->initialize();
+ if (is_conjugate()) {
+ updater->save_posterior_hypers(prior->get_hypers_proto());
+ }
initialize_state();
- posterior_hypers = *hypers;
- clear_data();
- clear_summary_statistics();
- }
-
- //! Sets the (pointer to the) dataset matrix
- void set_dataset(const Eigen::MatrixXd *const dataset) override {
- dataset_ptr = dataset;
- }
+ like->clear_data();
+ like->clear_summary_statistics();
+ };
- protected:
- //! Raises an error if the prior pointer is not initialized
- void check_prior_is_set() const {
- if (prior == nullptr) {
- throw std::invalid_argument("Hierarchy prior was not provided");
- }
- }
+ //! Returns whether the hierarchy models multivariate data or not
+ bool is_multivariate() const override { return like->is_multivariate(); };
- //! Re-initializes the prior of the hierarchy to a newly created object
- void create_empty_prior() { prior.reset(new Prior); }
+ //! Returns whether the hierarchy depends on covariate values or not
+ bool is_dependent() const override { return like->is_dependent(); };
- //! Re-initializes the hypers of the hierarchy to a newly created object
- void create_empty_hypers() { hypers.reset(new Hyperparams); }
+ //! Returns whether the hierarchy represents a conjugate model or not
+ bool is_conjugate() const override { return updater->is_conjugate(); };
- //! Sets the cardinality of the cluster
- void set_card(const int card_) {
- card = card_;
- log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_);
+ //! Sets the (pointer to the) dataset matrix
+ void set_dataset(const Eigen::MatrixXd *const dataset) override {
+ like->set_dataset(dataset);
}
+ protected:
//! Initializes state parameters to appropriate values
virtual void initialize_state() = 0;
- //! Writes current value of hyperparameters to a Protobuf message and
- //! return a shared_ptr.
- //! New hierarchies have to first modify the field 'oneof val' in the
- //! AlgoritmState::HierarchyHypers message by adding the appropriate type
- virtual std::shared_ptr
- get_hypers_proto() const = 0;
-
- //! Initializes hierarchy hyperparameters to appropriate values
- virtual void initialize_hypers() = 0;
-
- //! Resets cardinality and indexes of data in this cluster
- void clear_data() {
- set_card(0);
- cluster_data_idx = std::set();
- }
-
- virtual void clear_summary_statistics() = 0;
-
- //! Down-casts the given generic proto message to a ClusterState proto
- bayesmix::AlgorithmState::ClusterState *downcast_state(
- google::protobuf::Message *const state_) const {
- return google::protobuf::internal::down_cast<
- bayesmix::AlgorithmState::ClusterState *>(state_);
- }
-
- //! Down-casts the given generic proto message to a ClusterState proto
- const bayesmix::AlgorithmState::ClusterState &downcast_state(
- const google::protobuf::Message &state_) const {
- return google::protobuf::internal::down_cast<
- const bayesmix::AlgorithmState::ClusterState &>(state_);
- }
-
- //! Down-casts the given generic proto message to a HierarchyHypers proto
- bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers(
- google::protobuf::Message *const state_) const {
- return google::protobuf::internal::down_cast<
- bayesmix::AlgorithmState::HierarchyHypers *>(state_);
- }
-
- //! Down-casts the given generic proto message to a HierarchyHypers proto
- const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers(
- const google::protobuf::Message &state_) const {
- return google::protobuf::internal::down_cast<
- const bayesmix::AlgorithmState::HierarchyHypers &>(state_);
- }
-
- //! Container for state values
- State state;
-
- //! Container for prior hyperparameters values
- std::shared_ptr hypers;
-
- //! Container for posterior hyperparameters values
- Hyperparams posterior_hypers;
-
- //! Pointer to a Protobuf prior object for this class
- std::shared_ptr prior;
-
- //! Set of indexes of data points belonging to this cluster
- std::set cluster_data_idx;
-
- //! Current cardinality of this cluster
- int card = 0;
-
- //! Logarithm of current cardinality of this cluster
- double log_card = stan::math::NEGATIVE_INFTY;
-
- //! Pointer to the dataset matrix for the mixture model
- const Eigen::MatrixXd *dataset_ptr = nullptr;
-};
-
-template
-void BaseHierarchy::add_datum(
- const int id, const Eigen::RowVectorXd &datum,
- const bool update_params /*= false*/,
- const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) {
- assert(cluster_data_idx.find(id) == cluster_data_idx.end());
- card += 1;
- log_card = std::log(card);
- static_cast(this)->update_ss(datum, covariate, true);
- cluster_data_idx.insert(id);
- if (update_params) {
- static_cast(this)->save_posterior_hypers();
- }
-}
-
-template
-void BaseHierarchy::remove_datum(
- const int id, const Eigen::RowVectorXd &datum,
- const bool update_params /*= false*/,
- const Eigen::RowVectorXd &covariate /* = Eigen::RowVectorXd(0)*/) {
- static_cast(this)->update_ss(datum, covariate, false);
- set_card(card - 1);
- auto it = cluster_data_idx.find(id);
- assert(it != cluster_data_idx.end());
- cluster_data_idx.erase(it);
- if (update_params) {
- static_cast(this)->save_posterior_hypers();
- }
-}
-
-template
-void BaseHierarchy::write_state_to_proto(
- google::protobuf::Message *const out) const {
- std::shared_ptr state_ =
- get_state_proto();
- auto *out_cast = downcast_state(out);
- out_cast->CopyFrom(*state_.get());
- out_cast->set_cardinality(card);
-}
-
-template
-void BaseHierarchy::write_hypers_to_proto(
- google::protobuf::Message *const out) const {
- std::shared_ptr hypers_ =
- get_hypers_proto();
- auto *out_cast = downcast_hypers(out);
- out_cast->CopyFrom(*hypers_.get());
-}
-
-template
-Eigen::VectorXd
-BaseHierarchy::like_lpdf_grid(
- const Eigen::MatrixXd &data,
- const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const {
- Eigen::VectorXd lpdf(data.rows());
- if (covariates.cols() == 0) {
- // Pass null value as covariate
- for (int i = 0; i < data.rows(); i++) {
- lpdf(i) = static_cast(this)->get_like_lpdf(
- data.row(i), Eigen::RowVectorXd(0));
- }
- } else if (covariates.rows() == 1) {
- // Use unique covariate
- for (int i = 0; i < data.rows(); i++) {
- lpdf(i) = static_cast(this)->get_like_lpdf(
- data.row(i), covariates.row(0));
- }
- } else {
- // Use different covariates
- for (int i = 0; i < data.rows(); i++) {
- lpdf(i) = static_cast(this)->get_like_lpdf(
- data.row(i), covariates.row(i));
+ //! Evaluates the log-marginal distribution of data in a single point
+ //! @param hier_params Pointer to the container (prior or posterior)
+ //! hyperparameter values
+ //! @param datum Point which is to be evaluated
+ //! @return The evaluation of the lpdf
+ virtual double marg_lpdf(ProtoHypersPtr hier_params,
+ const Eigen::RowVectorXd &datum) const {
+ if (!is_conjugate()) {
+ throw std::runtime_error(
+ "Call marg_lpdf() for a non-conjugate hierarchy");
+ } else {
+ throw std::runtime_error(
+ "marg_lpdf() not implemented for this hierarchy");
}
}
- return lpdf;
-}
-
-template
-void BaseHierarchy::sample_full_cond(
- const Eigen::MatrixXd &data,
- const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) {
- clear_data();
- clear_summary_statistics();
- if (covariates.cols() == 0) {
- // Pass null value as covariate
- for (int i = 0; i < data.rows(); i++) {
- static_cast(this)->add_datum(i, data.row(i), false,
- Eigen::RowVectorXd(0));
- }
- } else if (covariates.rows() == 1) {
- // Use unique covariate
- for (int i = 0; i < data.rows(); i++) {
- static_cast(this)->add_datum(i, data.row(i), false,
- covariates.row(0));
- }
- } else {
- // Use different covariates
- for (int i = 0; i < data.rows(); i++) {
- static_cast(this)->add_datum(i, data.row(i), false,
- covariates.row(i));
+
+ //! Evaluates the log-marginal distribution of data in a single point
+ //! @param hier_params Pointer to the container of (prior or posterior)
+ //! hyperparameter values
+ //! @param datum Point which is to be evaluated
+ //! @param covariate Covariate vector associated to datum
+ //! @return The evaluation of the lpdf
+ virtual double marg_lpdf(ProtoHypersPtr hier_params,
+ const Eigen::RowVectorXd &datum,
+ const Eigen::RowVectorXd &covariate) const {
+ if (!is_conjugate()) {
+ throw std::runtime_error(
+ "Call marg_lpdf() for a non-conjugate hierarchy");
+ } else {
+ throw std::runtime_error(
+ "marg_lpdf() not implemented for this hierarchy");
}
}
- static_cast(this)->sample_full_cond(true);
-}
+};
#endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_
diff --git a/src/hierarchies/conjugate_hierarchy.h b/src/hierarchies/conjugate_hierarchy.h
deleted file mode 100644
index 4d2430bea..000000000
--- a/src/hierarchies/conjugate_hierarchy.h
+++ /dev/null
@@ -1,206 +0,0 @@
-#ifndef BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_
-#define BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_
-
-#include "base_hierarchy.h"
-
-//! Template base class for conjugate hierarchy objects.
-
-//! This class acts as the base class for conjugate models, i.e. ones for which
-//! both the prior and posterior distribution have the same form
-//! (non-conjugate hierarchies should instead inherit directly from
-//! `BaseHierarchy`). This also means that the marginal distribution for the
-//! data is available in closed form. For this reason, each class deriving from
-//! this one must have a free method with one of the following signatures,
-//! based on whether it depends on covariates or not:
-//! double marg_lpdf(
-//! const Hyperparams ¶ms, const Eigen::RowVectorXd &datum,
-//! const Eigen::RowVectorXd &covariate) const;
-//! or
-//! double marg_lpdf(
-//! const Hyperparams ¶ms, const Eigen::RowVectorXd &datum) const;
-//! This returns the evaluation of the marginal distribution on the given data
-//! point (and covariate, if any), conditioned on the provided `Hyperparams`
-//! object. The latter may contain either prior or posterior values for
-//! hyperparameters, depending on where this function is called within the
-//! library.
-//! For more information, please refer to parent classes `AbstractHierarchy`
-//! and `BaseHierarchy`.
-
-template
-class ConjugateHierarchy
- : public BaseHierarchy {
- public:
- using BaseHierarchy::hypers;
- using BaseHierarchy::posterior_hypers;
- using BaseHierarchy::state;
-
- ConjugateHierarchy() = default;
- ~ConjugateHierarchy() = default;
-
- //! Public wrapper for `marg_lpdf()` methods
- virtual double get_marg_lpdf(
- const Hyperparams ¶ms, const Eigen::RowVectorXd &datum,
- const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const;
-
- //! Evaluates the log-prior predictive distribution of data in a single point
- //! @param datum Point which is to be evaluated
- //! @param covariate (Optional) covariate vector associated to datum
- //! @return The evaluation of the lpdf
- double prior_pred_lpdf(const Eigen::RowVectorXd &datum,
- const Eigen::RowVectorXd &covariate =
- Eigen::RowVectorXd(0)) const override {
- return get_marg_lpdf(*hypers, datum, covariate);
- }
-
- //! Evaluates the log-conditional predictive distr. of data in a single point
- //! @param datum Point which is to be evaluated
- //! @param covariate (Optional) covariate vector associated to datum
- //! @return The evaluation of the lpdf
- double conditional_pred_lpdf(const Eigen::RowVectorXd &datum,
- const Eigen::RowVectorXd &covariate =
- Eigen::RowVectorXd(0)) const override {
- return get_marg_lpdf(posterior_hypers, datum, covariate);
- }
-
- //! Evaluates the log-prior predictive distr. of data in a grid of points
- //! @param data Grid of points (by row) which are to be evaluated
- //! @param covariates (Optional) covariate vectors associated to data
- //! @return The evaluation of the lpdf
- virtual Eigen::VectorXd prior_pred_lpdf_grid(
- const Eigen::MatrixXd &data,
- const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0,
- 0)) const override;
-
- //! Evaluates the log-prior predictive distr. of data in a grid of points
- //! @param data Grid of points (by row) which are to be evaluated
- //! @param covariates (Optional) covariate vectors associated to data
- //! @return The evaluation of the lpdf
- virtual Eigen::VectorXd conditional_pred_lpdf_grid(
- const Eigen::MatrixXd &data,
- const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0,
- 0)) const override;
-
- //! Generates new state values from the centering posterior distribution
- //! @param update_params Save posterior hypers after the computation?
- void sample_full_cond(const bool update_params = true) override {
- if (this->card == 0) {
- // No posterior update possible
- static_cast(this)->sample_prior();
- } else {
- Hyperparams params =
- update_params
- ? static_cast(this)->compute_posterior_hypers()
- : posterior_hypers;
- state = static_cast(this)->draw(params);
- }
- }
-
- //! Saves posterior hyperparameters to the corresponding class member
- void save_posterior_hypers() {
- posterior_hypers =
- static_cast(this)->compute_posterior_hypers();
- }
-
- //! Returns whether the hierarchy represents a conjugate model or not
- bool is_conjugate() const override { return true; }
-
- protected:
- //! Evaluates the log-marginal distribution of data in a single point
- //! @param params Container of (prior or posterior) hyperparameter values
- //! @param datum Point which is to be evaluated
- //! @param covariate Covariate vector associated to datum
- //! @return The evaluation of the lpdf
- virtual double marg_lpdf(const Hyperparams ¶ms,
- const Eigen::RowVectorXd &datum,
- const Eigen::RowVectorXd &covariate) const {
- if (!this->is_dependent()) {
- throw std::runtime_error(
- "Cannot call marg_lpdf() from a non-dependent hierarchy");
- } else {
- throw std::runtime_error("marg_lpdf() not implemented");
- }
- }
-
- //! Evaluates the log-marginal distribution of data in a single point
- //! @param params Container of (prior or posterior) hyperparameter values
- //! @param datum Point which is to be evaluated
- //! @return The evaluation of the lpdf
- virtual double marg_lpdf(const Hyperparams ¶ms,
- const Eigen::RowVectorXd &datum) const {
- if (this->is_dependent()) {
- throw std::runtime_error(
- "Cannot call marg_lpdf() from a dependent hierarchy");
- } else {
- throw std::runtime_error("marg_lpdf() not implemented");
- }
- }
-};
-
-template
-double ConjugateHierarchy::get_marg_lpdf(
- const Hyperparams ¶ms, const Eigen::RowVectorXd &datum,
- const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const {
- if (this->is_dependent()) {
- return marg_lpdf(params, datum, covariate);
- } else {
- return marg_lpdf(params, datum);
- }
-}
-
-template
-Eigen::VectorXd
-ConjugateHierarchy::prior_pred_lpdf_grid(
- const Eigen::MatrixXd &data,
- const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const {
- Eigen::VectorXd lpdf(data.rows());
- if (covariates.cols() == 0) {
- // Pass null value as covariate
- for (int i = 0; i < data.rows(); i++) {
- lpdf(i) = static_cast(this)->prior_pred_lpdf(
- data.row(i), Eigen::RowVectorXd(0));
- }
- } else if (covariates.rows() == 1) {
- // Use unique covariate
- for (int i = 0; i < data.rows(); i++) {
- lpdf(i) = static_cast(this)->prior_pred_lpdf(
- data.row(i), covariates.row(0));
- }
- } else {
- // Use different covariates
- for (int i = 0; i < data.rows(); i++) {
- lpdf(i) = static_cast(this)->prior_pred_lpdf(
- data.row(i), covariates.row(i));
- }
- }
- return lpdf;
-}
-
-template
-Eigen::VectorXd ConjugateHierarchy::
- conditional_pred_lpdf_grid(
- const Eigen::MatrixXd &data,
- const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const {
- Eigen::VectorXd lpdf(data.rows());
- if (covariates.cols() == 0) {
- // Pass null value as covariate
- for (int i = 0; i < data.rows(); i++) {
- lpdf(i) = static_cast(this)->conditional_pred_lpdf(
- data.row(i), Eigen::RowVectorXd(0));
- }
- } else if (covariates.rows() == 1) {
- // Use unique covariate
- for (int i = 0; i < data.rows(); i++) {
- lpdf(i) = static_cast(this)->conditional_pred_lpdf(
- data.row(i), covariates.row(0));
- }
- } else {
- // Use different covariates
- for (int i = 0; i < data.rows(); i++) {
- lpdf(i) = static_cast(this)->conditional_pred_lpdf(
- data.row(i), covariates.row(i));
- }
- }
- return lpdf;
-}
-
-#endif // BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_
diff --git a/src/hierarchies/fa_hierarchy.cc b/src/hierarchies/fa_hierarchy.cc
deleted file mode 100644
index 38a76c79d..000000000
--- a/src/hierarchies/fa_hierarchy.cc
+++ /dev/null
@@ -1,300 +0,0 @@
-#include "fa_hierarchy.h"
-
-#include
-
-#include
-#include
-#include
-#include
-#include
-
-#include "algorithm_state.pb.h"
-#include "hierarchy_prior.pb.h"
-#include "ls_state.pb.h"
-#include "src/utils/proto_utils.h"
-#include "src/utils/rng.h"
-
-double FAHierarchy::like_lpdf(const Eigen::RowVectorXd& datum) const {
- return bayesmix::multi_normal_lpdf_woodbury_chol(
- datum, state.mu, state.psi_inverse, state.cov_wood, state.cov_logdet);
-}
-
-FA::State FAHierarchy::draw(const FA::Hyperparams& params) {
- auto& rng = bayesmix::Rng::Instance().get();
- FA::State out;
- out.mu = params.mutilde;
- out.psi = params.beta / (params.alpha0 + 1.);
- out.eta = Eigen::MatrixXd::Zero(card, params.q);
- out.lambda = Eigen::MatrixXd::Zero(dim, params.q);
-
- for (size_t j = 0; j < dim; j++) {
- out.mu[j] =
- stan::math::normal_rng(params.mutilde[j], sqrt(params.phi), rng);
-
- out.psi[j] = stan::math::inv_gamma_rng(params.alpha0, params.beta[j], rng);
-
- for (size_t i = 0; i < params.q; i++) {
- out.lambda(j, i) = stan::math::normal_rng(0, 1, rng);
- }
- }
-
- for (size_t i = 0; i < card; i++) {
- for (size_t j = 0; j < params.q; j++) {
- out.eta(i, j) = stan::math::normal_rng(0, 1, rng);
- }
- }
-
- out.psi_inverse = out.psi.cwiseInverse().asDiagonal();
- compute_wood_factors(out.cov_wood, out.cov_logdet, out.lambda,
- out.psi_inverse);
-
- return out;
-}
-
-void FAHierarchy::initialize_state() {
- state.mu = hypers->mutilde;
- state.psi = hypers->beta / (hypers->alpha0 + 1.);
- state.eta = Eigen::MatrixXd::Zero(card, hypers->q);
- state.lambda = Eigen::MatrixXd::Zero(dim, hypers->q);
- state.psi_inverse = state.psi.cwiseInverse().asDiagonal();
- compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda,
- state.psi_inverse);
-}
-
-void FAHierarchy::initialize_hypers() {
- if (prior->has_fixed_values()) {
- // Set values
- hypers->mutilde = bayesmix::to_eigen(prior->fixed_values().mutilde());
- dim = hypers->mutilde.size();
- hypers->beta = bayesmix::to_eigen(prior->fixed_values().beta());
- hypers->phi = prior->fixed_values().phi();
- hypers->alpha0 = prior->fixed_values().alpha0();
- hypers->q = prior->fixed_values().q();
-
- // Automatic initialization
- if (dim == 0) {
- hypers->mutilde = dataset_ptr->colwise().mean();
- dim = hypers->mutilde.size();
- }
- if (hypers->beta.size() == 0) {
- Eigen::MatrixXd centered =
- dataset_ptr->rowwise() - dataset_ptr->colwise().mean();
- auto cov_llt = ((centered.transpose() * centered) /
- double(dataset_ptr->rows() - 1.))
- .llt();
- Eigen::MatrixXd precision_matrix(
- cov_llt.solve(Eigen::MatrixXd::Identity(dim, dim)));
- hypers->beta =
- (hypers->alpha0 - 1) * precision_matrix.diagonal().cwiseInverse();
- if (hypers->alpha0 == 1) {
- throw std::invalid_argument(
- "Scale parameter must be different than 1 when automatic "
- "initialization is used");
- }
- }
- // Check validity
- if (dim != hypers->beta.rows()) {
- throw std::invalid_argument(
- "Hyperparameters dimensions are not consistent");
- }
- for (size_t j = 0; j < dim; j++) {
- if (hypers->beta[j] <= 0) {
- throw std::invalid_argument("Shape parameter must be > 0");
- }
- }
- if (hypers->alpha0 <= 0) {
- throw std::invalid_argument("Scale parameter must be > 0");
- }
- if (hypers->phi <= 0) {
- throw std::invalid_argument("Diffusion parameter must be > 0");
- }
- if (hypers->q <= 0) {
- throw std::invalid_argument("Number of factors must be > 0");
- }
- }
-
- else {
- throw std::invalid_argument("Unrecognized hierarchy prior");
- }
-}
-
-void FAHierarchy::update_hypers(
- const std::vector& states) {
- auto& rng = bayesmix::Rng::Instance().get();
- if (prior->has_fixed_values()) {
- return;
- }
-
- else {
- throw std::invalid_argument("Unrecognized hierarchy prior");
- }
-}
-
-void FAHierarchy::update_summary_statistics(const Eigen::RowVectorXd& datum,
- const bool add) {
- if (add) {
- data_sum += datum;
- } else {
- data_sum -= datum;
- }
-}
-
-void FAHierarchy::clear_summary_statistics() {
- data_sum = Eigen::VectorXd::Zero(dim);
-}
-
-void FAHierarchy::set_state_from_proto(
- const google::protobuf::Message& state_) {
- auto& statecast = downcast_state(state_);
- state.mu = bayesmix::to_eigen(statecast.fa_state().mu());
- state.psi = bayesmix::to_eigen(statecast.fa_state().psi());
- state.eta = bayesmix::to_eigen(statecast.fa_state().eta());
- state.lambda = bayesmix::to_eigen(statecast.fa_state().lambda());
- state.psi_inverse = state.psi.cwiseInverse().asDiagonal();
- compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda,
- state.psi_inverse);
- set_card(statecast.cardinality());
-}
-
-std::shared_ptr
-FAHierarchy::get_state_proto() const {
- bayesmix::FAState state_;
- bayesmix::to_proto(state.mu, state_.mutable_mu());
- bayesmix::to_proto(state.psi, state_.mutable_psi());
- bayesmix::to_proto(state.eta, state_.mutable_eta());
- bayesmix::to_proto(state.lambda, state_.mutable_lambda());
-
- auto out = std::make_shared();
- out->mutable_fa_state()->CopyFrom(state_);
- return out;
-}
-
-void FAHierarchy::set_hypers_from_proto(
- const google::protobuf::Message& hypers_) {
- auto& hyperscast = downcast_hypers(hypers_).fa_state();
- hypers->mutilde = bayesmix::to_eigen(hyperscast.mutilde());
- hypers->alpha0 = hyperscast.alpha0();
- hypers->beta = bayesmix::to_eigen(hyperscast.beta());
- hypers->phi = hyperscast.phi();
- hypers->q = hyperscast.q();
-}
-
-std::shared_ptr
-FAHierarchy::get_hypers_proto() const {
- bayesmix::FAPriorDistribution hypers_;
- bayesmix::to_proto(hypers->mutilde, hypers_.mutable_mutilde());
- bayesmix::to_proto(hypers->beta, hypers_.mutable_beta());
- hypers_.set_alpha0(hypers->alpha0);
- hypers_.set_phi(hypers->phi);
- hypers_.set_q(hypers->q);
-
- auto out = std::make_shared();
- out->mutable_fa_state()->CopyFrom(hypers_);
- return out;
-}
-
-void FAHierarchy::sample_full_cond(const bool update_params /*= false*/) {
- if (this->card == 0) {
- // No posterior update possible
- sample_prior();
- } else {
- sample_eta();
- sample_mu();
- sample_psi();
- sample_lambda();
- }
-}
-
-void FAHierarchy::sample_eta() {
- auto& rng = bayesmix::Rng::Instance().get();
- auto sigma_eta_inv_llt =
- (Eigen::MatrixXd::Identity(hypers->q, hypers->q) +
- state.lambda.transpose() * state.psi_inverse * state.lambda)
- .llt();
- if (state.eta.rows() != card) {
- state.eta = Eigen::MatrixXd::Zero(card, state.eta.cols());
- }
- Eigen::MatrixXd temp_product(
- sigma_eta_inv_llt.solve(state.lambda.transpose() * state.psi_inverse));
- auto iterator = cluster_data_idx.begin();
- for (size_t i = 0; i < card; i++, iterator++) {
- Eigen::VectorXd tempvector(dataset_ptr->row(
- *iterator)); // TODO use slicing when Eigen is updated to v3.4
- state.eta.row(i) = (bayesmix::multi_normal_prec_chol_rng(
- temp_product * (tempvector - state.mu), sigma_eta_inv_llt, rng));
- }
-}
-
-void FAHierarchy::sample_mu() {
- auto& rng = bayesmix::Rng::Instance().get();
- Eigen::DiagonalMatrix sigma_mu;
-
- sigma_mu.diagonal() =
- (card * state.psi_inverse.diagonal().array() + hypers->phi)
- .cwiseInverse();
-
- Eigen::VectorXd sum = (state.eta.colwise().sum());
-
- Eigen::VectorXd mumean =
- sigma_mu * (hypers->phi * hypers->mutilde +
- state.psi_inverse * (data_sum - state.lambda * sum));
-
- state.mu = bayesmix::multi_normal_diag_rng(mumean, sigma_mu, rng);
-}
-
-void FAHierarchy::sample_lambda() {
- auto& rng = bayesmix::Rng::Instance().get();
-
- Eigen::MatrixXd temp_etateta(state.eta.transpose() * state.eta);
-
- for (size_t j = 0; j < dim; j++) {
- auto sigma_lambda_inv_llt =
- (Eigen::MatrixXd::Identity(hypers->q, hypers->q) +
- temp_etateta / state.psi[j])
- .llt();
- Eigen::VectorXd tempsum(card);
- const Eigen::VectorXd& data_col = dataset_ptr->col(j);
- auto iterator = cluster_data_idx.begin();
- for (size_t i = 0; i < card; i++, iterator++) {
- tempsum[i] = data_col(
- *iterator); // TODO use slicing when Eigen is updated to v3.4
- }
- tempsum = tempsum.array() - state.mu[j];
- tempsum = tempsum.array() / state.psi[j];
-
- state.lambda.row(j) = bayesmix::multi_normal_prec_chol_rng(
- sigma_lambda_inv_llt.solve(state.eta.transpose() * tempsum),
- sigma_lambda_inv_llt, rng);
- }
-}
-
-void FAHierarchy::sample_psi() {
- auto& rng = bayesmix::Rng::Instance().get();
-
- for (size_t j = 0; j < dim; j++) {
- double sum = 0;
- auto iterator = cluster_data_idx.begin();
- for (size_t i = 0; i < card; i++, iterator++) {
- sum += std::pow(
- ((*dataset_ptr)(*iterator, j) -
- state.mu[j] - // TODO use slicing when Eigen is updated to v3.4
- state.lambda.row(j).dot(state.eta.row(i))),
- 2);
- }
- state.psi[j] = stan::math::inv_gamma_rng(hypers->alpha0 + card / 2,
- hypers->beta[j] + sum / 2, rng);
- }
- state.psi_inverse = state.psi.cwiseInverse().asDiagonal();
- compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda,
- state.psi_inverse);
-}
-
-void FAHierarchy::compute_wood_factors(
- Eigen::MatrixXd& cov_wood, double& cov_logdet,
- const Eigen::MatrixXd& lambda,
- const Eigen::DiagonalMatrix& psi_inverse) {
- auto [cov_wood_, cov_logdet_] =
- bayesmix::compute_wood_chol_and_logdet(psi_inverse, lambda);
- cov_logdet = cov_logdet_;
- cov_wood = cov_wood_;
-}
diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h
index 8b6da31d4..fc7ce3079 100644
--- a/src/hierarchies/fa_hierarchy.h
+++ b/src/hierarchies/fa_hierarchy.h
@@ -1,134 +1,111 @@
#ifndef BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_
#define BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_
-#include
-
-#include
-#include
-#include
-
-#include "algorithm_state.pb.h"
#include "base_hierarchy.h"
#include "hierarchy_id.pb.h"
-#include "hierarchy_prior.pb.h"
+#include "likelihoods/fa_likelihood.h"
+#include "priors/fa_prior_model.h"
#include "src/utils/distributions.h"
-
-//! Mixture of Factor Analysers hierarchy for multivariate data.
-
-//! This class represents a hierarchical model where data are distributed
-
-namespace FA {
-//! Custom container for State values
-struct State {
- Eigen::VectorXd mu, psi;
- Eigen::MatrixXd eta, lambda, cov_wood;
- Eigen::DiagonalMatrix psi_inverse;
- double cov_logdet;
-};
-
-//! Custom container for Hyperparameters values
-struct Hyperparams {
- Eigen::VectorXd mutilde, beta;
- double phi, alpha0;
- unsigned int q;
-};
-
-}; // namespace FA
-
-class FAHierarchy : public BaseHierarchy {
+#include "updaters/fa_updater.h"
+
+/**
+ * Mixture of Factor Analysers hierarchy for multivariate data.
+ *
+ * This class represents a hierarchical model where data are distributed
+ * according to a multivariate Normal likelihood with a specific factorization
+ * of the covariance matrix (see the `FALikelihood` class for details). The
+ * likelihood parameters have a Dirichlet-Laplace distribution x InverseGamma
+ * centering distribution (see the `FAPriorModel` class for details). That is:
+ *
+ * \f[
+ * f(x_i \mid \mu, \Sigma, \Lambda) &= N(\mu, \Sigma + \Lambda \Lambda^T) \\
+ * \mu &\sim N_p(\tilde \mu, \psi I) \\
+ * \Lambda &\sim DL(\alpha) \\
+ * \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\
+ * \sigma^2_j &\sim IG(a,b) \quad j=1,...,p
+ * \f]
+ *
+ * where Lambda is the latent score matrix (size \f$ p \times d \f$
+ * with \f$ d << p \f$) and \f$ DL(\alpha) \f$ is the Laplace-Dirichlet
+ * distribution. See Bhattacharya et al. (2015) for further details
+ */
+
+class FAHierarchy
+ : public BaseHierarchy {
public:
FAHierarchy() = default;
~FAHierarchy() = default;
- //! Updates hyperparameter values given a vector of cluster states
- void update_hypers(const std::vector&
- states) override;
-
- //! Updates state values using the given (prior or posterior) hyperparameters
- FA::State draw(const FA::Hyperparams& params);
-
- //! Resets summary statistics for this cluster
- void clear_summary_statistics() override;
-
//! Returns the Protobuf ID associated to this class
bayesmix::HierarchyId get_id() const override {
return bayesmix::HierarchyId::FA;
}
- //! Read and set state values from a given Protobuf message
- void set_state_from_proto(const google::protobuf::Message& state_) override;
-
- //! Read and set hyperparameter values from a given Protobuf message
- void set_hypers_from_proto(
- const google::protobuf::Message& hypers_) override;
-
- //! Writes current state to a Protobuf message and return a shared_ptr
- //! New hierarchies have to first modify the field 'oneof val' in the
- //! AlgoritmState::ClusterState message by adding the appropriate type
- std::shared_ptr get_state_proto()
- const override;
-
- //! Writes current value of hyperparameters to a Protobuf message and
- //! return a shared_ptr.
- //! New hierarchies have to first modify the field 'oneof val' in the
- //! AlgoritmState::HierarchyHypers message by adding the appropriate type
- std::shared_ptr get_hypers_proto()
- const override;
-
- //! Returns whether the hierarchy models multivariate data or not
- bool is_multivariate() const override { return true; }
-
- //! Saves posterior hyperparameters to the corresponding class member
- void save_posterior_hypers() {
- // No Hyperprior present in the hierarchy
- }
-
- //! Generates new state values from the centering posterior distribution
- //! @param update_params Save posterior hypers after the computation?
- void sample_full_cond(const bool update_params = false) override;
-
- protected:
- //! Evaluates the log-likelihood of data in a single point
- //! @param datum Point which is to be evaluated
- //! @return The evaluation of the lpdf
- double like_lpdf(const Eigen::RowVectorXd& datum) const override;
-
- //! Updates cluster statistics when a datum is added or removed from it
- //! @param datum Data point which is being added or removed
- //! @param add Whether the datum is being added or removed
- void update_summary_statistics(const Eigen::RowVectorXd& datum,
- const bool add) override;
+ //! Sets the default updater algorithm for this hierarchy
+ void set_default_updater() { updater = std::make_shared(); }
//! Initializes state parameters to appropriate values
- void initialize_state() override;
-
- //! Initializes hierarchy hyperparameters to appropriate values
- void initialize_hypers() override;
-
- //! Gibbs sampling step for state variable eta
- void sample_eta();
-
- //! Gibbs sampling step for state variable mu
- void sample_mu();
-
- //! Gibbs sampling step for state variable psi
- void sample_psi();
-
- //! Gibbs sampling step for state variable lambda
- void sample_lambda();
-
- //! Helper function to compute factors needed for likelihood evaluation
- void compute_wood_factors(
- Eigen::MatrixXd& cov_wood, double& cov_logdet,
- const Eigen::MatrixXd& lambda,
- const Eigen::DiagonalMatrix& psi_inverse);
+ void initialize_state() override {
+ // Initialize likelihood dimension to prior one
+ like->set_dim(prior->get_dim());
+ // Get hypers and data dimension
+ auto hypers = prior->get_hypers();
+ unsigned int dim = like->get_dim();
+ // Initialize likelihood state
+ State::FA state;
+ state.mu = hypers.mutilde;
+ state.psi = hypers.beta / (hypers.alpha0 + 1.);
+ state.lambda = Eigen::MatrixXd::Zero(dim, hypers.q);
+ state.psi_inverse = state.psi.cwiseInverse().asDiagonal();
+ state.compute_wood_factors();
+ like->set_state(state);
+ }
+};
- //! Sum of data points currently belonging to the cluster
- Eigen::VectorXd data_sum;
+//! Empirical-Bayes hyperparameters initialization for the FAHierarchy.
+//! Sets the hyperparameters in `hier` starting from the data on which the user
+//! wants to fit the model.
+inline void set_fa_hyperparams_from_data(FAHierarchy* hier) {
+ // TODO test this function
+ auto dataset_ptr =
+ std::static_pointer_cast(hier->get_likelihood())
+ ->get_dataset();
+ auto hypers =
+ std::static_pointer_cast(hier->get_prior())->get_hypers();
+ unsigned int dim =
+ std::static_pointer_cast(hier->get_likelihood())
+ ->get_dim();
+
+ // Automatic initialization
+ if (dim == 0) {
+ hypers.mutilde = dataset_ptr->colwise().mean();
+ dim = hypers.mutilde.size();
+ }
+ if (hypers.beta.size() == 0) {
+ Eigen::MatrixXd centered =
+ dataset_ptr->rowwise() - dataset_ptr->colwise().mean();
+ auto cov_llt =
+ ((centered.transpose() * centered) / double(dataset_ptr->rows() - 1.))
+ .llt();
+ Eigen::MatrixXd precision_matrix(
+ cov_llt.solve(Eigen::MatrixXd::Identity(dim, dim)));
+ hypers.beta =
+ (hypers.alpha0 - 1) * precision_matrix.diagonal().cwiseInverse();
+ if (hypers.alpha0 == 1) {
+ throw std::invalid_argument(
+ "Scale parameter must be different than 1 when automatic "
+ "initialization is used");
+ }
+ }
- //! Number of variables for each datum
- size_t dim;
+ bayesmix::AlgorithmState::HierarchyHypers state;
+ bayesmix::to_proto(hypers.mutilde,
+ state.mutable_fa_state()->mutable_mutilde());
+ bayesmix::to_proto(hypers.beta, state.mutable_fa_state()->mutable_beta());
+ state.mutable_fa_state()->set_alpha0(hypers.alpha0);
+ state.mutable_fa_state()->set_phi(hypers.phi);
+ state.mutable_fa_state()->set_q(hypers.q);
+ hier->get_prior()->set_hypers_from_proto(state);
};
#endif // BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_
diff --git a/src/hierarchies/lapnig_hierarchy.cc b/src/hierarchies/lapnig_hierarchy.cc
deleted file mode 100644
index b0d479244..000000000
--- a/src/hierarchies/lapnig_hierarchy.cc
+++ /dev/null
@@ -1,215 +0,0 @@
-#include "lapnig_hierarchy.h"
-
-#include
-
-#include
-#include
-#include
-#include
-#include
-
-#include "algorithm_state.pb.h"
-#include "hierarchy_prior.pb.h"
-#include "ls_state.pb.h"
-#include "src/utils/rng.h"
-
-unsigned int LapNIGHierarchy::accepted_ = 0;
-unsigned int LapNIGHierarchy::iter_ = 0;
-
-void LapNIGHierarchy::set_state_from_proto(
- const google::protobuf::Message &state_) {
- auto &statecast = downcast_state(state_);
- state.mean = statecast.uni_ls_state().mean();
- state.scale = statecast.uni_ls_state().var();
- set_card(statecast.cardinality());
-}
-
-void LapNIGHierarchy::set_hypers_from_proto(
- const google::protobuf::Message &hypers_) {
- auto &hyperscast = downcast_hypers(hypers_).lapnig_state();
- hypers->mean = hyperscast.mean();
- hypers->var = hyperscast.var();
- hypers->scale = hyperscast.scale();
- hypers->shape = hyperscast.shape();
- hypers->mh_mean_var = hyperscast.mh_mean_var();
- hypers->mh_log_scale_var = hyperscast.mh_log_scale_var();
-}
-
-double LapNIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const {
- return stan::math::double_exponential_lpdf(datum(0), state.mean,
- state.scale);
-}
-
-std::shared_ptr
-LapNIGHierarchy::get_state_proto() const {
- bayesmix::UniLSState state_;
- state_.set_mean(state.mean);
- state_.set_var(state.scale);
-
- auto out = std::make_shared();
- out->mutable_uni_ls_state()->CopyFrom(state_);
- return out;
-}
-
-std::shared_ptr
-LapNIGHierarchy::get_hypers_proto() const {
- bayesmix::LapNIGState hypers_;
- hypers_.set_mean(hypers->mean);
- hypers_.set_var(hypers->var);
- hypers_.set_shape(hypers->shape);
- hypers_.set_scale(hypers->scale);
- hypers_.set_mh_mean_var(hypers->mh_mean_var);
- hypers_.set_mh_log_scale_var(hypers->mh_log_scale_var);
-
- auto out = std::make_shared();
- out->mutable_lapnig_state()->CopyFrom(hypers_);
- return out;
-}
-
-void LapNIGHierarchy::clear_summary_statistics() {
- cluster_data_values.clear();
- sum_abs_diff_curr = 0;
- sum_abs_diff_prop = 0;
-}
-
-void LapNIGHierarchy::initialize_hypers() {
- if (prior->has_fixed_values()) {
- // Set values
- hypers->mean = prior->fixed_values().mean();
- hypers->var = prior->fixed_values().var();
- hypers->shape = prior->fixed_values().shape();
- hypers->scale = prior->fixed_values().scale();
- hypers->mh_mean_var = prior->fixed_values().mh_mean_var();
- hypers->mh_log_scale_var = prior->fixed_values().mh_log_scale_var();
- // Check validity
- if (hypers->var <= 0) {
- throw std::invalid_argument("Variance-scaling parameter must be > 0");
- }
- if (hypers->shape <= 0) {
- throw std::invalid_argument("Shape parameter must be > 0");
- }
- if (hypers->scale <= 0) {
- throw std::invalid_argument("Scale parameter must be > 0");
- }
- if (hypers->mh_mean_var <= 0) {
- throw std::invalid_argument("mh_mean_var parameter must be > 0");
- }
- if (hypers->mh_log_scale_var <= 0) {
- throw std::invalid_argument("mh_log_scale_var parameter must be > 0");
- }
- } else {
- throw std::invalid_argument("Unrecognized hierarchy prior");
- }
-}
-
-void LapNIGHierarchy::initialize_state() {
- state.mean = hypers->mean;
- state.scale = hypers->scale / (hypers->shape + 1); // mode of Inv-Gamma
-}
-
-void LapNIGHierarchy::update_hypers(
- const std::vector &states) {
- auto &rng = bayesmix::Rng::Instance().get();
- if (prior->has_fixed_values()) {
- return;
- } else {
- throw std::invalid_argument("Unrecognized hierarchy prior");
- }
-}
-
-LapNIG::State LapNIGHierarchy::draw(const LapNIG::Hyperparams ¶ms) {
- auto &rng = bayesmix::Rng::Instance().get();
- LapNIG::State out{};
- out.scale = stan::math::inv_gamma_rng(params.shape, 1. / params.scale, rng);
- out.mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng);
- return out;
-}
-
-void LapNIGHierarchy::update_summary_statistics(
- const Eigen::RowVectorXd &datum, const bool add) {
- if (add) {
- sum_abs_diff_curr += std::abs(state.mean - datum(0, 0));
- cluster_data_values.push_back(datum);
- } else {
- sum_abs_diff_curr -= std::abs(state.mean - datum(0, 0));
- auto it = std::find(cluster_data_values.begin(), cluster_data_values.end(),
- datum);
- cluster_data_values.erase(it);
- }
-}
-
-void LapNIGHierarchy::sample_full_cond(const bool update_params /*= false*/) {
- if (this->card == 0) {
- // No posterior update possible
- this->sample_prior();
- } else {
- // Number of iterations to compute the acceptance rate
- ++iter_;
-
- // Random generator
- auto &rng = bayesmix::Rng::Instance().get();
-
- // Candidate mean and candidate log_scale
- Eigen::VectorXd curr_unc_params(2);
- curr_unc_params << state.mean, std::log(state.scale);
-
- Eigen::VectorXd prop_unc_params = propose_rwmh(curr_unc_params);
-
- double log_target_prop =
- eval_prior_lpdf_unconstrained(prop_unc_params) +
- eval_like_lpdf_unconstrained(prop_unc_params, false);
-
- double log_target_curr =
- eval_prior_lpdf_unconstrained(curr_unc_params) +
- eval_like_lpdf_unconstrained(curr_unc_params, true);
-
- double log_a_rate = log_target_prop - log_target_curr;
-
- if (std::log(stan::math::uniform_rng(0, 1, rng)) < log_a_rate) {
- ++accepted_;
- state.mean = prop_unc_params(0);
- state.scale = std::exp(prop_unc_params(1));
- sum_abs_diff_curr = sum_abs_diff_prop;
- }
- }
-}
-
-Eigen::VectorXd LapNIGHierarchy::propose_rwmh(
- const Eigen::VectorXd &curr_vals) {
- auto &rng = bayesmix::Rng::Instance().get();
- double candidate_mean =
- curr_vals(0) + stan::math::normal_rng(0, sqrt(hypers->mh_mean_var), rng);
- double candidate_log_scale =
- curr_vals(1) +
- stan::math::normal_rng(0, sqrt(hypers->mh_log_scale_var), rng);
- Eigen::VectorXd proposal(2);
- proposal << candidate_mean, candidate_log_scale;
- return proposal;
-}
-
-double LapNIGHierarchy::eval_prior_lpdf_unconstrained(
- const Eigen::VectorXd &unconstrained_parameters) {
- double mu = unconstrained_parameters(0);
- double log_scale = unconstrained_parameters(1);
- double scale = std::exp(log_scale);
- return stan::math::normal_lpdf(mu, hypers->mean, std::sqrt(hypers->var)) +
- stan::math::inv_gamma_lpdf(scale, hypers->shape, hypers->scale) +
- log_scale;
-}
-
-double LapNIGHierarchy::eval_like_lpdf_unconstrained(
- const Eigen::VectorXd &unconstrained_parameters, const bool is_current) {
- double mean = unconstrained_parameters(0);
- double log_scale = unconstrained_parameters(1);
- double scale = std::exp(log_scale);
- double diff_sum = 0; // Sum of absolute values of data - candidate_mean
- if (is_current) {
- diff_sum = sum_abs_diff_curr;
- } else {
- for (auto &elem : cluster_data_values) {
- diff_sum += std::abs(elem(0, 0) - mean);
- }
- sum_abs_diff_prop = diff_sum;
- }
- return std::log(0.5 / scale) + (-0.5 / scale * diff_sum);
-}
diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/lapnig_hierarchy.h
index cdb07e55b..01574da19 100644
--- a/src/hierarchies/lapnig_hierarchy.h
+++ b/src/hierarchies/lapnig_hierarchy.h
@@ -1,146 +1,57 @@
#ifndef BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_
#define BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_
-#include
-
-#include
-#include
-#include
-#include
-
-#include "algorithm_state.pb.h"
#include "base_hierarchy.h"
#include "hierarchy_id.pb.h"
-#include "hierarchy_prior.pb.h"
-
-//! Laplace Normal-InverseGamma hierarchy for univariate data.
-
-//! This class represents a hierarchical model where data are distributed
-//! according to a laplace likelihood, the parameters of which have a
-//! Normal-InverseGamma centering distribution. That is:
-//! f(x_i|mu,lambda) = Laplace(mu,lambda)
-//! (mu,lambda) ~ N-IG(mu0, lambda0, alpha0, beta0)
-//! The state is composed of mean and scale. The state hyperparameters,
-//! contained in the Hypers object, are (mu_0, lambda0, alpha0, beta0,
-//! scale_var, mean_var), all scalar values. Note that this hierarchy is NOT
-//! conjugate, thus the marginal distribution is not available in closed form.
-//! The hyperprameters scale_var and mean_var are used to perform a step of
-//! Random Walk Metropolis Hastings to sample from the full conditionals. For
-//! more information, please refer to parent classes: `AbstractHierarchy` and
-//! `BaseHierarchy`.
-
-namespace LapNIG {
-//! Custom container for State values
-struct State {
- double mean, scale;
-};
-
-//! Custom container for Hyperparameters values
-struct Hyperparams {
- double mean, var, shape, scale, mh_log_scale_var, mh_mean_var;
-};
-} // namespace LapNIG
+#include "likelihoods/laplace_likelihood.h"
+#include "priors/nxig_prior_model.h"
+#include "updaters/mala_updater.h"
+
+/**
+ * Laplace Normal-InverseGamma hierarchy for univariate data.
+ *
+ * This class represents a hierarchical model where data are distributed
+ * according to a Laplace likelihood (see the `LaplaceLikelihood` class for
+ * details). The likelihood parameters have a Normal x InverseGamma centering
+ * distribution (see the `NxIGPriorModel` class for details). That is:
+ *
+ * \f{aligned}{
+ * f(x_i \mid \mu,\sigma^2) &= Laplace(\mu,\sqrt{\sigma^2/2})\\
+ * \mu &\sim N(\mu_0,\eta^2) \\
+ * \sigma^2 &\sim InvGamma(a, b)
+ * \f}
+ * The state is composed of mean and variance (thus the scale for the Laplace
+ * distribution is \f$ \sqrt{\sigma^2/2} \f$). The state hyperparameters are
+ * \f$(\mu_0, \eta^2, a, b)\f$, all scalar values. Note that this hierarchy
+ * is NOT conjugate, thus the marginal distribution is not available in closed
+ * form.
+ */
class LapNIGHierarchy
- : public BaseHierarchy {
+    : public BaseHierarchy<LapNIGHierarchy, LaplaceLikelihood,
+                           NxIGPriorModel, MALAUpdater> {
public:
LapNIGHierarchy() = default;
~LapNIGHierarchy() = default;
- //! Counters for tracking acceptance rate in MH step
- static unsigned int accepted_;
- static unsigned int iter_;
-
//! Returns the Protobuf ID associated to this class
bayesmix::HierarchyId get_id() const override {
return bayesmix::HierarchyId::LapNIG;
}
- //! Read and set state values from a given Protobuf message
- void set_state_from_proto(const google::protobuf::Message &state_) override;
-
- //! Read and set hyperparameter values from a given Protobuf message
- void set_hypers_from_proto(
- const google::protobuf::Message &hypers_) override;
-
- void save_posterior_hypers() {
- throw std::runtime_error("save_posterior_hypers() not implemented");
- }
-
- //! Returns whether the hierarchy models multivariate data or not
- bool is_multivariate() const override { return false; }
-
- //! Writes current state to a Protobuf message and return a shared_ptr
- //! New hierarchies have to first modify the field 'oneof val' in the
- //! AlgoritmState::ClusterState message by adding the appropriate type
- std::shared_ptr get_state_proto()
- const override;
-
- //! Writes current value of hyperparameters to a Protobuf message and
- //! return a shared_ptr.
- //! New hierarchies have to first modify the field 'oneof val' in the
- //! AlgoritmState::HierarchyHypers message by adding the appropriate type
- std::shared_ptr get_hypers_proto()
- const override;
-
- //! Resets summary statistics for this cluster
- void clear_summary_statistics() override;
-
- //! Updates hyperparameter values given a vector of cluster states
- void update_hypers(const std::vector
- &states) override;
-
- //! Updates state values using the given (prior or posterior) hyperparameters
- LapNIG::State draw(const LapNIG::Hyperparams ¶ms);
-
- //! Generates new state values from the centering posterior distribution
- //! @param update_params Save posterior hypers after the computation?
- void sample_full_cond(const bool update_params = false) override;
-
- protected:
- //! Set of values of data points belonging to this cluster
- std::list cluster_data_values;
-
- //! Sum of absolute differences for current params
- double sum_abs_diff_curr = 0;
-
- //! Sum of absolute differences for proposal params
- double sum_abs_diff_prop = 0;
-
- //! Samples from the proposal distribution using Random Walk
- //! Metropolis-Hastings
- Eigen::VectorXd propose_rwmh(const Eigen::VectorXd &curr_vals);
-
- //! Evaluates the prior given the mean (unconstrained_parameters(0))
- //! and log of the scale (unconstrained_parameters(1))
- double eval_prior_lpdf_unconstrained(
- const Eigen::VectorXd &unconstrained_parameters);
-
- //! Evaluates the (sum of the) log likelihood for all the observations in the
- //! cluster given the mean (unconstrained_parameters(0))
- //! and log of the scale (unconstrained_parameters(1)).
- //! The parameter "is_current" is used to identify if the evaluation of the
- //! likelihood is on the current or on the proposed parameters, in order to
- //! avoid repeating calculations of the sum of the absolute differences
- double eval_like_lpdf_unconstrained(
- const Eigen::VectorXd &unconstrained_parameters, const bool is_current);
-
- //! Evaluates the log-likelihood of data in a single point
- //! @param datum Point which is to be evaluated
- //! @return The evaluation of the lpdf
- double like_lpdf(const Eigen::RowVectorXd &datum) const override;
-
- //! Updates cluster statistics when a datum is added or removed from it
- //! @param datum Data point which is being added or removed
- //! @param add Whether the datum is being added or removed
- void update_summary_statistics(const Eigen::RowVectorXd &datum,
- const bool add) override;
-
- //! Initializes hierarchy hyperparameters to appropriate values
- void initialize_hypers() override;
+ //! Sets the default updater algorithm for this hierarchy
+ void set_default_updater() { updater = std::make_shared<MALAUpdater>(); }
//! Initializes state parameters to appropriate values
- void initialize_state() override;
+ void initialize_state() override {
+ // Get hypers
+ auto hypers = prior->get_hypers();
+ // Initialize likelihood state
+ State::UniLS state;
+ state.mean = hypers.mean;
+ state.var = hypers.scale / (hypers.shape + 1);
+ like->set_state(state);
+ };
};
+
#endif // BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_
diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt
new file mode 100644
index 000000000..df10e8674
--- /dev/null
+++ b/src/hierarchies/likelihoods/CMakeLists.txt
@@ -0,0 +1,17 @@
+target_sources(bayesmix PUBLIC
+ likelihood_internal.h
+ abstract_likelihood.h
+ base_likelihood.h
+ uni_norm_likelihood.h
+ uni_norm_likelihood.cc
+ multi_norm_likelihood.h
+ multi_norm_likelihood.cc
+ uni_lin_reg_likelihood.h
+ uni_lin_reg_likelihood.cc
+ laplace_likelihood.h
+ laplace_likelihood.cc
+ fa_likelihood.h
+ fa_likelihood.cc
+)
+
+add_subdirectory(states)
diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h
new file mode 100644
index 000000000..38425829a
--- /dev/null
+++ b/src/hierarchies/likelihoods/abstract_likelihood.h
@@ -0,0 +1,190 @@
+#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_ABSTRACT_LIKELIHOOD_H_
+#define BAYESMIX_HIERARCHIES_LIKELIHOODS_ABSTRACT_LIKELIHOOD_H_
+
+#include