From 2bfb9605b1a05c38976a1ec2b3dcad65c9995643 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 20 Dec 2021 20:15:29 +0100 Subject: [PATCH 001/317] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6f5a0091e..b0e0a223b 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ To build the executable for the main file `run.cc`, please use the following lis ```shell mkdir build cd build -cmake .. -DDISABLE_DOCS -DDISABLE_BENCHMARKS -DDISABLE_TESTS +cmake .. -DDISABLE_DOCS=ON -DDISABLE_BENCHMARKS=ON -DDISABLE_TESTS=ON make run cd .. ``` From 8a346005eb19e23b47787b4a743568e07ffdfa1b Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 27 Dec 2021 11:08:19 +0100 Subject: [PATCH 002/317] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 14c71bea1..ccdaa3a47 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ To build the executable for the main file `run_mcmc.cc`, please use the followin mkdir build cd build cmake .. -DDISABLE_DOCS=ON -DDISABLE_BENCHMARKS=ON -DDISABLE_TESTS=ON -make run +make run_mcmc cd .. ``` From 4ccbba321f25edbed31473c5213e95d212de7b9d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 27 Dec 2021 11:10:30 +0100 Subject: [PATCH 003/317] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ccdaa3a47..a21e59c8d 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ where P is either the Dirichlet process or the Pitman--Yor process To install and use `bayesmix`, please `cd` to the folder to which you wish to install it, and clone this repository with the following command-line instruction: ```shell -git clone --recurse-submodule git@github.com:bayesmix-dev/bayesmix.git +git clone --recurse-submodules git@github.com:bayesmix-dev/bayesmix.git ``` Then, by using `cd bayesmix`, you will enter the newly downloaded folder. From f5f8400122de4db95bdbdd82971a07a8a21f8da5 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 27 Dec 2021 21:19:31 +0100 Subject: [PATCH 004/317] Added NNxIG hierarchy --- src/hierarchies/nnxig_hierarchy.cc | 152 +++++++++++++++++++++++++++++ src/hierarchies/nnxig_hierarchy.h | 120 +++++++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 src/hierarchies/nnxig_hierarchy.cc create mode 100644 src/hierarchies/nnxig_hierarchy.h diff --git a/src/hierarchies/nnxig_hierarchy.cc b/src/hierarchies/nnxig_hierarchy.cc new file mode 100644 index 000000000..f7bc62a16 --- /dev/null +++ b/src/hierarchies/nnxig_hierarchy.cc @@ -0,0 +1,152 @@ +#include "nnxig_hierarchy.h" + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "hierarchy_prior.pb.h" +#include "ls_state.pb.h" +#include "src/utils/rng.h" + +double NNxIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { + return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); +} + +void NNxIGHierarchy::initialize_state() { + state.mean = hypers->mean; + state.var = hypers->scale / (hypers->shape + 1); +} + +void NNxIGHierarchy::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->mean = prior->fixed_values().mean(); + hypers->var = prior->fixed_values().var(); + hypers->shape = prior->fixed_values().shape(); + hypers->scale = prior->fixed_values().scale(); + + // Check validity + if (hypers->var <= 0) { + throw std::invalid_argument("Variance parameter must be > 0"); + } + if (hypers->shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers->scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void NNxIGHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, + bool add) { + if (add) { + data_sum += datum(0); + data_sum_squares += datum(0) * datum(0); + } else { + data_sum -= datum(0); + data_sum_squares -= datum(0) * datum(0); + } +} + +void NNxIGHierarchy::update_hypers( + const std::vector &states) { + auto &rng = bayesmix::Rng::Instance().get(); + if (prior->has_fixed_values()) { + return; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void NNxIGHierarchy::clear_summary_statistics() { + data_sum = 0; + data_sum_squares = 0; +} + +void NNxIGHierarchy::set_state_from_proto( + const google::protobuf::Message &state_) { + auto &statecast = downcast_state(state_); + state.mean = statecast.uni_ls_state().mean(); + state.var = statecast.uni_ls_state().var(); + set_card(statecast.cardinality()); +} + +void NNxIGHierarchy::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).nnxig_state(); + hypers->mean = hyperscast.mean(); + hypers->var = hyperscast.var(); + hypers->scale = hyperscast.scale(); + hypers->shape = hyperscast.shape(); +} + +std::shared_ptr +NNxIGHierarchy::get_state_proto() const { + bayesmix::UniLSState state_; + state_.set_mean(state.mean); + state_.set_var(state.var); + + auto out = std::make_shared(); + out->mutable_uni_ls_state()->CopyFrom(state_); + return out; +} + +std::shared_ptr +NNxIGHierarchy::get_hypers_proto() const { + bayesmix::NxIGDistribution hypers_; + hypers_.set_mean(hypers->mean); + hypers_.set_var(hypers->var); + hypers_.set_shape(hypers->shape); + hypers_.set_scale(hypers->scale); + + auto out = std::make_shared(); + out->mutable_nnxig_state()->CopyFrom(hypers_); + return out; +} + +void NNxIGHierarchy::sample_full_cond(bool update_params) { + if (this->card == 0) { + // No posterior update possible + sample_prior(); + } else { + NNxIG::Hyperparams params = + update_params ? compute_posterior_hypers() : posterior_hypers; + state = draw(params); + } +} + +NNxIG::State NNxIGHierarchy::draw(const NNxIG::Hyperparams ¶ms) { + auto &rng = bayesmix::Rng::Instance().get(); + NNxIG::State out; + out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); + out.mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng); + return out; +} + +NNxIG::Hyperparams NNxIGHierarchy::compute_posterior_hypers() const { + // Initialize relevant variables + if (card == 0) { // no update possible + return *hypers; + } + // Compute posterior hyperparameters + NNxIG::Hyperparams post_params; + double var_y = data_sum_squares - 2 * state.mean * data_sum + + card * state.mean * state.mean; + post_params.mean = (hypers->var * data_sum + state.var * hypers->mean) / + (card * hypers->var + state.var); + post_params.var = + (state.var * hypers->var) / (card * hypers->var + state.var); + post_params.shape = hypers->shape + 0.5 * card; + post_params.scale = hypers->scale + 0.5 * var_y; + return post_params; +} + +void NNxIGHierarchy::save_posterior_hypers() { + posterior_hypers = compute_posterior_hypers(); +} diff --git a/src/hierarchies/nnxig_hierarchy.h b/src/hierarchies/nnxig_hierarchy.h new file mode 100644 index 000000000..de5e878f6 --- /dev/null +++ b/src/hierarchies/nnxig_hierarchy.h @@ -0,0 +1,120 @@ +#ifndef BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_hierarchy.h" +#include "hierarchy_id.pb.h" +#include "hierarchy_prior.pb.h" + +//! Non Conjugate Normal Normal-InverseGamma hierarchy for univariate data. + +//! This class represents a hierarchical model where data are distributed +//! according to a normal likelihood, the parameters of which have a +//! Normal-InverseGamma centering distribution. That is: +//! f(x_i|mu,sig) = N(mu,sig^2) +//! mu ~ N(mu0, sigma0) +//! sig^2 ~ IG(alpha0, beta0) +//! The state is composed of mean and variance. The state hyperparameters, +//! contained in the Hypers object, are (mu0, sigma0, alpha0, beta0), all +//! scalar values. Note that this hierarchy is non conjugate. + +namespace NNxIG { +//! Custom container for State values +struct State { + double mean, var; +}; + +//! Custom container for Hyperparameters values +struct Hyperparams { + double mean, var, shape, scale; +}; + +}; // namespace NNxIG + +class NNxIGHierarchy + : public BaseHierarchy { + public: + NNxIGHierarchy() = default; + ~NNxIGHierarchy() = default; + + //! Updates hyperparameter values given a vector of cluster states + void update_hypers(const std::vector + &states) override; + + //! Updates state values using the given (prior or posterior) hyperparameters + NNxIG::State draw(const NNxIG::Hyperparams ¶ms); + + //! Generates new state values from the centering posterior distribution + //! @param update_params Save posterior hypers after the computation? + void sample_full_cond(bool update_params = true) override; + + //! Saves posterior hyperparameters to the corresponding class member + void save_posterior_hypers(); + + //! Resets summary statistics for this cluster + void clear_summary_statistics() override; + + //! Returns the Protobuf ID associated to this class + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::NNxIG; + } + + //! Read and set state values from a given Protobuf message + void set_state_from_proto(const google::protobuf::Message &state_) override; + + //! Read and set hyperparameter values from a given Protobuf message + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + //! Writes current state to a Protobuf message and return a shared_ptr + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::ClusterState message by adding the appropriate type + std::shared_ptr get_state_proto() + const override; + + //! Writes current value of hyperparameters to a Protobuf message and + //! return a shared_ptr. + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::HierarchyHypers message by adding the appropriate type + std::shared_ptr get_hypers_proto() + const override; + + //! Computes and return posterior hypers given data currently in this cluster + NNxIG::Hyperparams compute_posterior_hypers() const; + + //! Returns whether the hierarchy models multivariate data or not + bool is_multivariate() const override { return false; } + + protected: + //! Evaluates the log-likelihood of data in a single point + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf + double like_lpdf(const Eigen::RowVectorXd &datum) const override; + + //! Updates cluster statistics when a datum is added or removed from it + //! @param datum Data point which is being added or removed + //! @param add Whether the datum is being added or removed + void update_summary_statistics(const Eigen::RowVectorXd &datum, + bool add) override; + + //! Initializes state parameters to appropriate values + void initialize_state() override; + + //! Initializes hierarchy hyperparameters to appropriate values + void initialize_hypers() override; + + //! Sum of data points currently belonging to the cluster + double data_sum = 0; + + //! Sum of squared data points currently belonging to the cluster + double data_sum_squares = 0; +}; + +#endif // BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ From 83d51050dccb3d89026377d3bf13c6bf002b7341 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 27 Dec 2021 21:20:29 +0100 Subject: [PATCH 005/317] Ignore bayesmixpy.egg-info/ --- python/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/python/.gitignore b/python/.gitignore index 1ed1ae3e9..2bf689f94 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1,3 +1,4 @@ __pycache__/ .ipynb_checkpoints/ *.csv +bayesmixpy.egg-info/ From ee3f9322f0bc844887f9e7238ca3c71b840df60a Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 27 Dec 2021 21:21:03 +0100 Subject: [PATCH 006/317] Added NNxIG hierarchy --- src/hierarchies/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index ae7cf29e7..55980c32c 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -9,4 +9,6 @@ target_sources(bayesmix nnig_hierarchy.cc nnw_hierarchy.h nnw_hierarchy.cc + nnxig_hierarchy.h + nnxig_hierarchy.cc ) From 49e035228ecb66860ee42d347882a853c1eb3319 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 27 Dec 2021 21:21:38 +0100 Subject: [PATCH 007/317] Added NNxIG hierarchy --- src/hierarchies/load_hierarchies.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/hierarchies/load_hierarchies.h b/src/hierarchies/load_hierarchies.h index 14ce2254d..d18f8e8d6 100644 --- a/src/hierarchies/load_hierarchies.h +++ b/src/hierarchies/load_hierarchies.h @@ -9,6 +9,7 @@ #include "lin_reg_uni_hierarchy.h" #include "nnig_hierarchy.h" #include "nnw_hierarchy.h" +#include "nnxig_hierarchy.h" #include "src/runtime/factory.h" //! Loads all available `Hierarchy` objects into the appropriate factory, so @@ -25,6 +26,9 @@ __attribute__((constructor)) static void load_hierarchies() { Builder NNIGbuilder = []() { return std::make_shared(); }; + Builder NNxIGbuilder = []() { + return std::make_shared(); + }; Builder NNWbuilder = []() { return std::make_shared(); }; @@ -33,6 +37,7 @@ __attribute__((constructor)) static void load_hierarchies() { }; factory.add_builder(NNIGHierarchy().get_id(), NNIGbuilder); + factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); factory.add_builder(NNWHierarchy().get_id(), NNWbuilder); factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); } From e016cc9d2fee41be825f5ceb8bf6f104ef05c2aa Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 27 Dec 2021 21:22:08 +0100 Subject: [PATCH 008/317] Added include for NNxIG hierarchy --- src/includes.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/includes.h b/src/includes.h index 4b9ba4e5a..077683610 100644 --- a/src/includes.h +++ b/src/includes.h @@ -13,6 +13,7 @@ #include "hierarchies/load_hierarchies.h" #include "hierarchies/nnig_hierarchy.h" #include "hierarchies/nnw_hierarchy.h" +#include "hierarchies/nnxig_hierarchy.h" #include "mixings/dirichlet_mixing.h" #include "mixings/load_mixings.h" #include "mixings/logit_sb_mixing.h" From 92ac1f87f6501a939c6e6d5c7cced3024f907f13 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 27 Dec 2021 21:22:59 +0100 Subject: [PATCH 009/317] Added some tests for the NNxIG hierarchy --- test/hierarchies.cc | 58 +++++++++++++++++++++++++++++++++++++++++++++ test/priors.cc | 35 +++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/test/hierarchies.cc b/test/hierarchies.cc index e4862cf4a..2612eeb35 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -8,6 +8,7 @@ #include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" #include "src/hierarchies/nnw_hierarchy.h" +#include "src/hierarchies/nnxig_hierarchy.h" #include "src/utils/proto_utils.h" #include "src/utils/rng.h" @@ -219,3 +220,60 @@ TEST(lin_reg_uni_hierarchy, misc) { ASSERT_GT(state.regression_coeffs(i), beta0(i)); } } + +TEST(nnxighierarchy, draw) { + auto hier = std::make_shared(); + bayesmix::NNxIGPrior prior; + double mu0 = 5.0; + double var0 = 1.0; + double alpha0 = 2.0; + double beta0 = 2.0; + prior.mutable_fixed_values()->set_mean(mu0); + prior.mutable_fixed_values()->set_var(var0); + prior.mutable_fixed_values()->set_shape(alpha0); + prior.mutable_fixed_values()->set_scale(beta0); + hier->get_mutable_prior()->CopyFrom(prior); + hier->initialize(); + + auto hier2 = hier->clone(); + hier2->sample_prior(); + + bayesmix::AlgorithmState out; + bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); + bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); + hier->write_state_to_proto(clusval); + hier2->write_state_to_proto(clusval2); + + ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +} + +TEST(nnxighierarchy, sample_given_data) { + auto hier = std::make_shared(); + bayesmix::NNxIGPrior prior; + double mu0 = 5.0; + double var0 = 1.0; + double alpha0 = 2.0; + double beta0 = 2.0; + prior.mutable_fixed_values()->set_mean(mu0); + prior.mutable_fixed_values()->set_var(var0); + prior.mutable_fixed_values()->set_shape(alpha0); + prior.mutable_fixed_values()->set_scale(beta0); + hier->get_mutable_prior()->CopyFrom(prior); + + hier->initialize(); + + Eigen::VectorXd datum(1); + datum << 4.5; + + auto hier2 = hier->clone(); + hier2->add_datum(0, datum, false); + hier2->sample_full_cond(); + + bayesmix::AlgorithmState out; + bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); + bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); + hier->write_state_to_proto(clusval); + hier2->write_state_to_proto(clusval2); + + ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +} diff --git a/test/priors.cc b/test/priors.cc index 6ad84a843..2b5893f27 100644 --- a/test/priors.cc +++ b/test/priors.cc @@ -6,6 +6,7 @@ #include "algorithm_state.pb.h" #include "src/hierarchies/nnig_hierarchy.h" #include "src/hierarchies/nnw_hierarchy.h" +#include "src/hierarchies/nnxig_hierarchy.h" #include "src/mixings/dirichlet_mixing.h" #include "src/utils/proto_utils.h" @@ -124,3 +125,37 @@ TEST(hierarchies, normal_mean_prior) { << std::endl; assert(mu00(0) < mean_out(0) && mu00(1) < mean_out(1)); } + +TEST(hierarchies, nxig_fixed_values) { + bayesmix::NNxIGPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + prior.mutable_fixed_values()->set_mean(5.0); + prior.mutable_fixed_values()->set_var(1.0); + prior.mutable_fixed_values()->set_shape(2.0); + prior.mutable_fixed_values()->set_scale(2.0); + + auto hier = std::make_shared(); + hier->get_mutable_prior()->CopyFrom(prior); + hier->initialize(); + + std::vector> unique_values; + std::vector states; + + // Check equality before update + unique_values.push_back(hier); + for (size_t i = 1; i < 4; i++) { + unique_values.push_back(hier->clone()); + unique_values[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnxig_state().DebugString()); + } + + // Check equality after update + unique_values[0]->update_hypers(states); + unique_values[0]->write_hypers_to_proto(&prior_out); + for (size_t i = 1; i < 4; i++) { + unique_values[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnxig_state().DebugString()); + } +} From 0393cce287c7bc57c6d1aaf8e61e23d350900514 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 27 Dec 2021 21:23:32 +0100 Subject: [PATCH 010/317] src/utils/cluster_utils.cc Uniformed console output --- src/utils/cluster_utils.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/utils/cluster_utils.cc b/src/utils/cluster_utils.cc index d1c63503c..6f43200b6 100644 --- a/src/utils/cluster_utils.cc +++ b/src/utils/cluster_utils.cc @@ -35,6 +35,7 @@ Eigen::VectorXi bayesmix::cluster_estimate( std::cout << "Done)" << std::endl; // Compute Frobenius norm error of all iterations + std::cout << "Computing Frobenius norm error... " << std::endl; Eigen::VectorXd errors(n_iter); for (int k = 0; k < n_iter; k++) { for (int i = 0; i < n_data; i++) { @@ -48,6 +49,7 @@ Eigen::VectorXi bayesmix::cluster_estimate( bar.display(); } bar.done(); + std::cout << "Done" << std::endl; // Print Ending Message // Find iteration with the least error std::ptrdiff_t ibest; From a865c7aebf812537b6520d63f5ba50a18255effc Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 27 Dec 2021 21:28:36 +0100 Subject: [PATCH 011/317] Added test notebook for NNxIG hierarchy --- python/notebooks/gaussian_mix_NNxIG.ipynb | 133 ++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 python/notebooks/gaussian_mix_NNxIG.ipynb diff --git a/python/notebooks/gaussian_mix_NNxIG.ipynb b/python/notebooks/gaussian_mix_NNxIG.ipynb new file mode 100644 index 000000000..9e92c562e --- /dev/null +++ b/python/notebooks/gaussian_mix_NNxIG.ipynb @@ -0,0 +1,133 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "49d3291e", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"BAYESMIX_EXE\"] = \"/home/m_gianella/Documents/GitHub/bayesmix/build/run_mcmc\"\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "from bayesmixpy import run_mcmc\n", + "from tensorflow_probability.substrates import numpy as tfp\n", + "tfd = tfp.distributions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64a83071", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(123)\n", + "\n", + "# Set true parameters\n", + "N = 500\n", + "Ncomp = 3\n", + "means = [-5.0, 0.0, 5.0]\n", + "sds = [0.5, 2.0, 0.25]\n", + "weights = np.ones(Ncomp)/Ncomp\n", + "\n", + "cluster_allocs = tfd.Categorical(probs=weights).sample(N)\n", + "data = np.stack([tfd.Normal(means[cluster_allocs[i]], sds[cluster_allocs[i]]).sample() for i in range(N)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13df394d", + "metadata": {}, + "outputs": [], + "source": [ + "# Setup parameters for bayesmixpy\n", + "hier_params = \\\n", + "\"\"\"\n", + "fixed_values {\n", + " mean: 0.0\n", + " var: 10.0\n", + " shape: 2.0\n", + " scale: 2.0\n", + "}\n", + "\"\"\"\n", + "\n", + "mix_params = \\\n", + "\"\"\"\n", + "dp_prior {\n", + " totalmass: 1\n", + "}\n", + "num_components: 3\n", + "\"\"\"\n", + "\n", + "algo_params = \\\n", + "\"\"\"\n", + "algo_id: \"BlockedGibbs\"\n", + "rng_seed: 20201124\n", + "iterations: 2000\n", + "burnin: 1000\n", + "init_num_clusters: 3\n", + "\"\"\"\n", + "\n", + "dens_grid = np.linspace(-7.5,7.5,1000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12505b6e", + "metadata": {}, + "outputs": [], + "source": [ + "# Fit model using bayesmixpy\n", + "eval_dens, n_clus, clus_chain, best_clus = run_mcmc(\"NNxIG\",\"TruncSB\", data,\n", + " hier_params, mix_params, algo_params,\n", + " dens_grid, return_num_clusters=False,\n", + " return_clusters=False, return_best_clus=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1eb6c0e9", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot mean posterior density\n", + "plt.plot(dens_grid, np.exp(eval_dens.mean(axis=0)))\n", + "plt.hist(data, alpha=0.4, density=True)\n", + "for c in np.unique(best_clus):\n", + " data_in_clus = data[best_clus == c]\n", + " plt.scatter(data_in_clus, np.zeros_like(data_in_clus) + 0.01)\n", + "plt.title(\"Posterior estimated density\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From b17802ac62c9d5d9255259a3b2d578e5a8400725 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 31 Dec 2021 15:41:49 +0100 Subject: [PATCH 012/317] Update proto files related to NNxIG hierarchy --- proto/algorithm_state.proto | 1 + proto/distribution.proto | 12 ++++++++++++ proto/hierarchy_id.proto | 1 + proto/hierarchy_prior.proto | 11 ++++++++++- 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/proto/algorithm_state.proto b/proto/algorithm_state.proto index ff2ad76a6..3305e2a6e 100644 --- a/proto/algorithm_state.proto +++ b/proto/algorithm_state.proto @@ -41,6 +41,7 @@ message AlgorithmState { NIGDistribution nnig_state = 2; NWDistribution nnw_state = 3; MultiNormalIGDistribution lin_reg_uni_state = 4; + NxIGDistribution nnxig_state = 5; } } HierarchyHypers hierarchy_hypers = 5; // The current values of the hyperparameters of the hierarchy diff --git a/proto/distribution.proto b/proto/distribution.proto index e27e89a29..c123765a9 100644 --- a/proto/distribution.proto +++ b/proto/distribution.proto @@ -57,6 +57,18 @@ message NIGDistribution { double scale = 4; } +/* + * Parameters of a Normal x Inverse-Gamma distribution + * with density + * f(x, y) = N(x | mu, var) * IG(y | shape, scale) + */ +message NxIGDistribution { + double mean = 1; + double var = 2; + double shape = 3; + double scale = 4; +} + /* * Parameters of a Normal Wishart distribution * with density diff --git a/proto/hierarchy_id.proto b/proto/hierarchy_id.proto index 0ebc8be77..b07bfcc29 100644 --- a/proto/hierarchy_id.proto +++ b/proto/hierarchy_id.proto @@ -10,4 +10,5 @@ enum HierarchyId { NNIG = 1; // Normal - Normal Inverse Gamma NNW = 2; // Normal - Normal Wishart LinRegUni = 3; // Linear Regression (univariate response) + NNxIG = 4; // Normal - Normal x Inverse Gamma } diff --git a/proto/hierarchy_prior.proto b/proto/hierarchy_prior.proto index 13de7c614..3b5ce6187 100644 --- a/proto/hierarchy_prior.proto +++ b/proto/hierarchy_prior.proto @@ -31,6 +31,16 @@ message NNIGPrior { } } +/* + * Prior for the parameters of the base measure in a Normal-Normal x Inverse Gamma hierarchy + */ +message NNxIGPrior { + + oneof prior { + NxIGDistribution fixed_values = 1; // no prior, just fixed values + } +} + /* * Prior for the parameters of the base measure in a Normal-Normal Wishart hierarchy */ @@ -57,7 +67,6 @@ message NNWPrior { } } - /* * Prior for the parameters of the base measure in a Normal mixture model with a covariate-dependent * location. From 87e52cb9d77a3f0f49f6e55560e640227cd390a5 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 11 Jan 2022 17:02:52 +0100 Subject: [PATCH 013/317] Definition of currently implementes states --- src/hierarchies/likelihoods/states.h | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 src/hierarchies/likelihoods/states.h diff --git a/src/hierarchies/likelihoods/states.h b/src/hierarchies/likelihoods/states.h new file mode 100644 index 000000000..73a368706 --- /dev/null +++ b/src/hierarchies/likelihoods/states.h @@ -0,0 +1,25 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOOD_STATES_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOOD_STATES_H_ + +#include + +namespace State { + +struct UniLS { + double mean, var; +}; + +struct MultiLS { + Eigen::VectorXd mean; + Eigen::MatrixXd prec, prec_chol; + double prec_logdet; +}; + +struct UniLinReg { + Eigen::VectorXd regression_coeffs; + double var; +}; + +} // namespace State + +#endif From c2e41a10a6e4b7425309584dfb7e848004148614 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 11 Jan 2022 17:03:32 +0100 Subject: [PATCH 014/317] API for likelihood class --- .../likelihoods/abstract_likelihood.h | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 src/hierarchies/likelihoods/abstract_likelihood.h diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h new file mode 100644 index 000000000..0eb420eed --- /dev/null +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -0,0 +1,94 @@ +#ifndef BAYESMIX_HIERARCHIES_ABSTRACT_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_ABSTRACT_LIKELIHOOD_H_ + +#include + +#include +#include + +// #include +// #include +// #include + +#include "algorithm_state.pb.h" +// #include "hierarchy_id.pb.h" +// #include "src/utils/rng.h" + +class AbstractLikelihood { + public: + virtual ~AbstractLikelihood() = default; + + virtual std::shared_ptr clone() const = 0; + + double lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const { + if (is_dependent()) { + return compute_lpdf(datum, covariate); + } else { + return compute_lpdf(datum); + } + } + + virtual Eigen::VectorXd lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const = 0; + + virtual bool is_multivariate() const = 0; + + virtual bool is_dependent() const { return false; } + + virtual void set_state_from_proto( + const google::protobuf::Message &state_) = 0; + + void update_sum_stats(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, bool add) { + if (is_dependent()) { + return update_summary_statistics(datum, covariate, add); + } else { + return update_summary_statistics(datum, add); + } + } + + protected: + virtual double compute_lpdf(const Eigen::RowVectorXd &datum) const { + if (is_dependent()) { + throw std::runtime_error( + "Cannot call this function from a dependent likelihood"); + } else { + throw std::runtime_error("Not implemented"); + } + } + + virtual double compute_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const { + if (!is_dependent()) { + throw std::runtime_error( + "Cannot call this function from a non-dependent likelihood"); + } else { + throw std::runtime_error("Not implemented"); + } + } + + virtual void update_summary_statistics(const Eigen::RowVectorXd &datum, + bool add) { + if (is_dependent()) { + throw std::runtime_error( + "Cannot call this function from a dependent hierarchy"); + } else { + throw std::runtime_error("Not implemented"); + } + } + + virtual void update_summary_statistics(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, + bool add) { + if (!is_dependent()) { + throw std::runtime_error( + "Cannot call this function from a non-dependent hierarchy"); + } else { + throw std::runtime_error("Not implemented"); + } + } +}; + +#endif From 6174d379e2791c8ec68e42c3a63200dcd1a883b9 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 11 Jan 2022 17:04:05 +0100 Subject: [PATCH 015/317] CRTP father class for likelihoods --- src/hierarchies/likelihoods/base_likelihood.h | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 src/hierarchies/likelihoods/base_likelihood.h diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h new file mode 100644 index 000000000..7db8d4674 --- /dev/null +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -0,0 +1,146 @@ +#ifndef BAYESMIX_HIERARCHIES_BASE_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_BASE_LIKELIHOOD_H_ + +#include + +#include +#include +// #include +#include +// #include + +#include "abstract_likelihood.h" +#include "algorithm_state.pb.h" + +template +class BaseLikelihood : public AbstractLikelihood { + public: + BaseLikelihood() = default; + + ~BaseLikelihood() = default; + + virtual std::shared_ptr clone() const override { + auto out = std::make_shared(static_cast(*this)); + out->clear_data(); + out->clear_summary_statistics(); + return out; + } + + virtual Eigen::VectorXd lpdf_grid(const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = + Eigen::MatrixXd(0, 0)) const override; + + int get_card() const override { return card; } + + double get_log_card() const override { return log_card; } + + std::set get_data_idx() const override { return cluster_data_idx; } + + void write_state_to_proto(google::protobuf::Message *out) const override; + + State get_state() const { return state; } + + void add_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + + void remove_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + + protected: + void set_card(const int card_) { + card = card_; + log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); + } + + virtual std::shared_ptr + get_state_proto() const = 0; + + void clear_data() { + set_card(0); + cluster_data_idx = std::set(); + } + + virtual void clear_summary_statistics() = 0; + + bayesmix::AlgorithmState::ClusterState *downcast_state( + google::protobuf::Message *state_) const { + return google::protobuf::internal::down_cast< + bayesmix::AlgorithmState::ClusterState *>(state_); + } + + const bayesmix::AlgorithmState::ClusterState &downcast_state( + const google::protobuf::Message &state_) const { + return google::protobuf::internal::down_cast< + const bayesmix::AlgorithmState::ClusterState &>(state_); + } + + State state; + + int card; + + int log_card; + + std::set cluster_data_idx; +}; + +template +void BaseLikelihood::add_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) { + assert(cluster_data_idx.find(id) == cluster_data_idx.end()); + card += 1; + log_card = std::log(card); + static_cast(this)->update_sum_stats(datum, covariate, true); + cluster_data_idx.insert(id); +} + +template +void BaseLikelihood::remove_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) { + static_cast(this)->update_sum_stats(datum, covariate, false); + set_card(card - 1); + auto it = cluster_data_idx.find(id); + assert(it != cluster_data_idx.end()); + cluster_data_idx.erase(it); +} + +template +void BaseLikelihood::write_state_to_proto( + google::protobuf::Message *out) const { + std::shared_ptr state_ = + get_state_proto(); + auto *out_cast = downcast_state(out); + out_cast->CopyFrom(*state_.get()); + out_cast->set_cardinality(card); +} + +template +Eigen::VectorXd BaseLikelihood::lpdf_grid( + const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { + Eigen::VectorXd lpdf(data.rows()); + if (covariates.cols() == 0) { + // Pass null value as covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->lpdf( + data.row(i), Eigen::RowVectorXd(0)); + } + } else if (covariates.rows() == 1) { + // Use unique covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->lpdf(data.row(i), + covariates.row(0)); + } + } else { + // Use different covariates + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->lpdf(data.row(i), + covariates.row(i)); + } + } + return lpdf; +} + +#endif From ad65aec81b5e54b86e91a1816387b318891a65a1 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 11 Jan 2022 17:04:36 +0100 Subject: [PATCH 016/317] Definitions of currently implemented hyper-parameters structs --- src/hierarchies/priors/hyperparams.h | 30 ++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 src/hierarchies/priors/hyperparams.h diff --git a/src/hierarchies/priors/hyperparams.h b/src/hierarchies/priors/hyperparams.h new file mode 100644 index 000000000..aaddb6c4a --- /dev/null +++ b/src/hierarchies/priors/hyperparams.h @@ -0,0 +1,30 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORMODEL_HYPERPARAMS_H_ +#define BAYESMIX_HIERARCHIES_PRIORMODEL_HYPERPARAMS_H_ + +#include + +namespace Hyperparams { + +struct NIG { + double mean, var_scaling, shape, scale; +}; + +struct NxIG { + double mean, var, shape, scale; +}; + +struct NW { + Eigen::VectorXd mean; + double var_scaling, deg_free; + Eigen::MatrixXd scale, scale_inv, scale_chol; +}; + +struct MNIG { + Eigen::VectorXd mean; + Eigen::MatrixXd var_scaling, var_scaling_inv; + double shape, scale; +}; + +} // namespace Hyperparams + +#endif From 023b0123134cdc8ed04ec6979a2044a9b33524f0 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 11 Jan 2022 17:05:05 +0100 Subject: [PATCH 017/317] API for prior model class --- src/hierarchies/priors/abstract_prior_model.h | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 src/hierarchies/priors/abstract_prior_model.h diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h new file mode 100644 index 000000000..e176d23bf --- /dev/null +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -0,0 +1,34 @@ +#ifndef BAYESMIX_HIERARCHIES_ABSTRACT_PRIORMODEL_H_ +#define BAYESMIX_HIERARCHIES_ABSTRACT_PRIORMODEL_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "src/utils/rng.h" + +class AbstractPriorModel { + public: + virtual ~AbstractPriorModel() = default; + + virtual double lpdf() = 0; + + virtual void sample_prior() = 0; + + virtual void update_hypers( + const std::vector &states) = 0; + + virtual void initialize_hypers() = 0; + + virtual google::protobuf::Message *get_mutable_prior() = 0; + + virtual void set_hypers_from_proto( + const google::protobuf::Message &state_) = 0; + + virtual void write_hypers_to_proto(google::protobuf::Message *out) const = 0; +}; + +#endif From 1ba20cdc25a09c581a459f7ed1488d3bf59de5c7 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 00:18:31 +0100 Subject: [PATCH 018/317] Add likelihoods/ source directory --- src/hierarchies/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index 55980c32c..567a6cd57 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -12,3 +12,5 @@ target_sources(bayesmix nnxig_hierarchy.h nnxig_hierarchy.cc ) + +add_subdirectory(likelihoods) From 9c6648ed37eba4ee018606030349252c8632e618 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 00:19:31 +0100 Subject: [PATCH 019/317] Add files to bayesmix target --- src/hierarchies/likelihoods/CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 src/hierarchies/likelihoods/CMakeLists.txt diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt new file mode 100644 index 000000000..216b0c849 --- /dev/null +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -0,0 +1,7 @@ +target_sources(bayesmix + PUBLIC + abstract_likelihood.h + base_likelihood.h + uni_norm_likelihood.h + uni_norm_likelihood.cc +) From bdc2edbbb9ec911d9b3b2ed505788d786eaec31a Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 00:20:11 +0100 Subject: [PATCH 020/317] Test for likelihood objects --- test/CMakeLists.txt | 25 +++++++++++++------------ test/likelihoods.cc | 29 +++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 12 deletions(-) create mode 100644 test/likelihoods.cc diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index dd2a016b1..62d34fde3 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -16,18 +16,19 @@ set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) add_executable(test_bayesmix $ - write_proto.cc - proto_utils.cc - hierarchies.cc - lpdf.cc - priors.cc - eigen_utils.cc - distributions.cc - semi_hdp.cc - collectors.cc - runtime.cc - rng.cc - logit_sb.cc + # write_proto.cc + # proto_utils.cc + # hierarchies.cc + # lpdf.cc + # priors.cc + # eigen_utils.cc + # distributions.cc + # semi_hdp.cc + # collectors.cc + # runtime.cc + # rng.cc + # logit_sb.cc + likelihoods.cc ) target_include_directories(test_bayesmix PUBLIC ${INCLUDE_PATHS}) diff --git a/test/likelihoods.cc b/test/likelihoods.cc new file mode 100644 index 000000000..7d329f34d --- /dev/null +++ b/test/likelihoods.cc @@ -0,0 +1,29 @@ +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "ls_state.pb.h" +#include "src/hierarchies/likelihoods/uni_norm_likelihood.h" +#include "src/utils/rng.h" + +TEST(uni_norm_likelihood, state_setget) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + state_.set_mean(5.23); + state_.set_var(1.02); + clust_state_.mutable_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Get state proto + auto out = like->get_state_proto(); + + // Check if coincides + ASSERT_EQ(out->DebugString(), clust_state_.DebugString()); +} From 63af6d5b4972d3557edc1f5007144a2d0a5517a2 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 00:21:33 +0100 Subject: [PATCH 021/317] Updates and bug fix --- src/hierarchies/likelihoods/abstract_likelihood.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 0eb420eed..9e27a967e 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -35,11 +35,14 @@ class AbstractLikelihood { virtual bool is_multivariate() const = 0; - virtual bool is_dependent() const { return false; } + virtual bool is_dependent() const = 0; virtual void set_state_from_proto( const google::protobuf::Message &state_) = 0; + virtual std::shared_ptr + get_state_proto() const = 0; + void update_sum_stats(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate, bool add) { if (is_dependent()) { @@ -49,6 +52,8 @@ class AbstractLikelihood { } } + virtual void clear_summary_statistics() = 0; + protected: virtual double compute_lpdf(const Eigen::RowVectorXd &datum) const { if (is_dependent()) { From 5dbc5bcda6cbc7ac1738dd1a4631d3fd78f163e5 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 00:21:58 +0100 Subject: [PATCH 022/317] Updates and bug fix --- src/hierarchies/likelihoods/base_likelihood.h | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 7db8d4674..ccf505c2d 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -16,10 +16,9 @@ template class BaseLikelihood : public AbstractLikelihood { public: BaseLikelihood() = default; - ~BaseLikelihood() = default; - virtual std::shared_ptr clone() const override { + virtual std::shared_ptr clone() const override { auto out = std::make_shared(static_cast(*this)); out->clear_data(); out->clear_summary_statistics(); @@ -30,23 +29,24 @@ class BaseLikelihood : public AbstractLikelihood { const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const override; - int get_card() const override { return card; } + int get_card() const { return card; } - double get_log_card() const override { return log_card; } + double get_log_card() const { return log_card; } - std::set get_data_idx() const override { return cluster_data_idx; } + std::set get_data_idx() const { return cluster_data_idx; } - void write_state_to_proto(google::protobuf::Message *out) const override; + void write_state_to_proto( + google::protobuf::Message *out) const; // override; State get_state() const { return state; } - void add_datum( - const int id, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + void add_datum(const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = + Eigen::RowVectorXd(0)); // override; - void remove_datum( - const int id, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + void remove_datum(const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = + Eigen::RowVectorXd(0)); // override; protected: void set_card(const int card_) { @@ -54,16 +54,11 @@ class BaseLikelihood : public AbstractLikelihood { log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); } - virtual std::shared_ptr - get_state_proto() const = 0; - void clear_data() { set_card(0); cluster_data_idx = std::set(); } - virtual void clear_summary_statistics() = 0; - bayesmix::AlgorithmState::ClusterState *downcast_state( google::protobuf::Message *state_) const { return google::protobuf::internal::down_cast< From b4db757320d88976b7ed856301f3f08366c38e7f Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 00:22:26 +0100 Subject: [PATCH 023/317] Create univariate normal likelihood class --- .../likelihoods/uni_norm_likelihood.cc | 0 .../likelihoods/uni_norm_likelihood.h | 53 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 src/hierarchies/likelihoods/uni_norm_likelihood.cc create mode 100644 src/hierarchies/likelihoods/uni_norm_likelihood.h diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc new file mode 100644 index 000000000..e69de29bb diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h new file mode 100644 index 000000000..cb56b6e0d --- /dev/null +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -0,0 +1,53 @@ +#ifndef BAYESMIX_HIERARCHIES_UNI_NORM_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_UNI_NORM_LIKELIHOOD_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states.h" + +class UniNormLikelihood + : public BaseLikelihood { + private: + double data_sum = 0; + double data_sum_squares = 0; + + public: + UniNormLikelihood() = default; + + ~UniNormLikelihood() = default; + + bool is_multivariate() const override { return false; }; + + bool is_dependent() const override { return false; }; + + void set_state_from_proto(const google::protobuf::Message &state_) override { + auto &statecast = downcast_state(state_); + state.mean = statecast.uni_ls_state().mean(); + state.var = statecast.uni_ls_state().var(); + set_card(statecast.cardinality()); + }; + + std::shared_ptr get_state_proto() + const override { + bayesmix::UniLSState state_; + state_.set_mean(state.mean); + state_.set_var(state.var); + + auto out = std::make_shared(); + out->mutable_uni_ls_state()->CopyFrom(state_); + return out; + }; + + void clear_summary_statistics() override { + data_sum = 0; + data_sum_squares = 0; + } +}; + +#endif From 5fe9bf254ac3d0e33d667a7dbbd70a37eb791b6f Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 17:41:22 +0100 Subject: [PATCH 024/317] Move methods' definitions to source file --- .../likelihoods/uni_norm_likelihood.cc | 39 +++++++++++++++++++ .../likelihoods/uni_norm_likelihood.h | 37 +++++------------- 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc index e69de29bb..9cdfb70aa 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -0,0 +1,39 @@ +#include "uni_norm_likelihood.h" + +double UniNormLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { + return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); +} + +void UniNormLikelihood::update_summary_statistics( + const Eigen::RowVectorXd &datum, bool add) { + if (add) { + data_sum += datum(0); + data_sum_squares += datum(0) * datum(0); + } else { + data_sum -= datum(0); + data_sum_squares -= datum(0) * datum(0); + } +} + +void UniNormLikelihood::set_state_from_proto( + const google::protobuf::Message &state_) { + auto &statecast = downcast_state(state_); + state.mean = statecast.uni_ls_state().mean(); + state.var = statecast.uni_ls_state().var(); + set_card(statecast.cardinality()); +} + +std::shared_ptr +UniNormLikelihood::get_state_proto() const { + bayesmix::UniLSState state_; + state_.set_mean(state.mean); + state_.set_var(state.var); + auto out = std::make_shared(); + out->mutable_uni_ls_state()->CopyFrom(state_); + return out; +} + +void UniNormLikelihood::clear_summary_statistics() { + data_sum = 0; + data_sum_squares = 0; +} diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index cb56b6e0d..e732b9058 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -5,6 +5,7 @@ #include #include +#include #include #include "algorithm_state.pb.h" @@ -13,41 +14,23 @@ class UniNormLikelihood : public BaseLikelihood { - private: - double data_sum = 0; - double data_sum_squares = 0; - public: UniNormLikelihood() = default; - ~UniNormLikelihood() = default; - bool is_multivariate() const override { return false; }; - bool is_dependent() const override { return false; }; - - void set_state_from_proto(const google::protobuf::Message &state_) override { - auto &statecast = downcast_state(state_); - state.mean = statecast.uni_ls_state().mean(); - state.var = statecast.uni_ls_state().var(); - set_card(statecast.cardinality()); - }; - + void set_state_from_proto(const google::protobuf::Message &state_) override; std::shared_ptr get_state_proto() - const override { - bayesmix::UniLSState state_; - state_.set_mean(state.mean); - state_.set_var(state.var); + const override; + void clear_summary_statistics() override; - auto out = std::make_shared(); - out->mutable_uni_ls_state()->CopyFrom(state_); - return out; - }; + protected: + double compute_lpdf(const Eigen::RowVectorXd &datum) const override; + void update_summary_statistics(const Eigen::RowVectorXd &datum, + bool add) override; - void clear_summary_statistics() override { - data_sum = 0; - data_sum_squares = 0; - } + double data_sum = 0; + double data_sum_squares = 0; }; #endif From 2f50e8c54294aea8516ac60c6a1c4981f928cfb5 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 17:42:10 +0100 Subject: [PATCH 025/317] Test whole API for UniNormLikelihood class --- test/likelihoods.cc | 46 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/test/likelihoods.cc b/test/likelihoods.cc index 7d329f34d..5a6447e4d 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -24,6 +24,50 @@ TEST(uni_norm_likelihood, state_setget) { // Get state proto auto out = like->get_state_proto(); - // Check if coincides + // Check if they coincides ASSERT_EQ(out->DebugString(), clust_state_.DebugString()); } + +TEST(uni_norm_likelihood, data_addremove) { + // Instance + auto like = std::make_shared(); + + // Add new datum to likelihood + Eigen::VectorXd datum(1); + datum << 5.0; + like->add_datum(0, datum); + + // Check if cardinality is augmented + ASSERT_EQ(like->get_card(), 1); + + // Remove datum from likelihood + like->remove_datum(0, datum); + + // Check if cardinality is reduced + ASSERT_EQ(like->get_card(), 0); +} + +TEST(uni_norm_likelihood, eval_lpdf) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + state_.set_mean(5); + state_.set_var(1); + clust_state_.mutable_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Add new datum to likelihood + Eigen::VectorXd data(3); + data << 4.5, 5.1, 2.5; + + // Compute lpdf on this grid of points + auto evals = like->lpdf_grid(data); + auto like_copy = like->clone(); + auto evals_copy = like_copy->lpdf_grid(data); + + // Check if they coincides + ASSERT_EQ(evals, evals_copy); +} From b0ed8917862601b2990942f3e2798a17aa56d9e3 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 19:27:24 +0100 Subject: [PATCH 026/317] Improved API for Likelihood classes --- .../likelihoods/abstract_likelihood.h | 18 ++++++++++++++++-- src/hierarchies/likelihoods/base_likelihood.h | 15 +++++++-------- .../likelihoods/uni_norm_likelihood.h | 4 ++-- src/hierarchies/priors/abstract_prior_model.h | 5 ++++- test/likelihoods.cc | 17 ++++++++++------- 5 files changed, 39 insertions(+), 20 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 9e27a967e..ed052f562 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -18,6 +18,7 @@ class AbstractLikelihood { public: virtual ~AbstractLikelihood() = default; + // IMPLEMENTED in BaseLikelihood virtual std::shared_ptr clone() const = 0; double lpdf(const Eigen::RowVectorXd &datum, @@ -40,8 +41,18 @@ class AbstractLikelihood { virtual void set_state_from_proto( const google::protobuf::Message &state_) = 0; - virtual std::shared_ptr - get_state_proto() const = 0; + // IMPLEMENTED in BaseLikelihood + virtual void write_state_to_proto(google::protobuf::Message *out) const = 0; + + // IMPLEMENTED in BaseLikelihood + virtual void add_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) = 0; + + // IMPLEMENTED in BaseLikelihood + virtual void remove_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) = 0; void update_sum_stats(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate, bool add) { @@ -55,6 +66,9 @@ class AbstractLikelihood { virtual void clear_summary_statistics() = 0; protected: + virtual std::shared_ptr + get_state_proto() const = 0; + virtual double compute_lpdf(const Eigen::RowVectorXd &datum) const { if (is_dependent()) { throw std::runtime_error( diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index ccf505c2d..11042edc5 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -35,18 +35,17 @@ class BaseLikelihood : public AbstractLikelihood { std::set get_data_idx() const { return cluster_data_idx; } - void write_state_to_proto( - google::protobuf::Message *out) const; // override; + void write_state_to_proto(google::protobuf::Message *out) const override; State get_state() const { return state; } - void add_datum(const int id, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = - Eigen::RowVectorXd(0)); // override; + void add_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; - void remove_datum(const int id, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = - Eigen::RowVectorXd(0)); // override; + void remove_datum( + const int id, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; protected: void set_card(const int card_) { diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index e732b9058..b5c04aaba 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -20,11 +20,11 @@ class UniNormLikelihood bool is_multivariate() const override { return false; }; bool is_dependent() const override { return false; }; void set_state_from_proto(const google::protobuf::Message &state_) override; - std::shared_ptr get_state_proto() - const override; void clear_summary_statistics() override; protected: + std::shared_ptr get_state_proto() + const override; double compute_lpdf(const Eigen::RowVectorXd &datum) const override; void update_summary_statistics(const Eigen::RowVectorXd &datum, bool add) override; diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index e176d23bf..a8f8dc7a5 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -14,9 +14,12 @@ class AbstractPriorModel { public: virtual ~AbstractPriorModel() = default; + virtual std::shared_ptr clone() const = 0; + virtual double lpdf() = 0; - virtual void sample_prior() = 0; + // Da pensare, come restituisco lo stato? + virtual void sample() = 0; virtual void update_hypers( const std::vector &states) = 0; diff --git a/test/likelihoods.cc b/test/likelihoods.cc index 5a6447e4d..0058538f1 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -13,19 +13,22 @@ TEST(uni_norm_likelihood, state_setget) { // Instance auto like = std::make_shared(); - // Set state from proto + // Prepare buffers bayesmix::UniLSState state_; - bayesmix::AlgorithmState::ClusterState clust_state_; + bayesmix::AlgorithmState::ClusterState set_state_; + bayesmix::AlgorithmState::ClusterState got_state_; + + // Prepare state state_.set_mean(5.23); state_.set_var(1.02); - clust_state_.mutable_uni_ls_state()->CopyFrom(state_); - like->set_state_from_proto(clust_state_); + set_state_.mutable_uni_ls_state()->CopyFrom(state_); - // Get state proto - auto out = like->get_state_proto(); + // Set and get the state + like->set_state_from_proto(set_state_); + like->write_state_to_proto(&got_state_); // Check if they coincides - ASSERT_EQ(out->DebugString(), clust_state_.DebugString()); + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); } TEST(uni_norm_likelihood, data_addremove) { From 9bed6e1ffceb7311a570f32f9762d5cbc3566ef2 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 19:29:05 +0100 Subject: [PATCH 027/317] Developing PriorModel Classes (ONGOING) --- src/hierarchies/priors/abstract_prior_model.h | 2 +- src/hierarchies/priors/base_prior_model.h | 81 +++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 src/hierarchies/priors/base_prior_model.h diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index a8f8dc7a5..1e7b96581 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -18,7 +18,7 @@ class AbstractPriorModel { virtual double lpdf() = 0; - // Da pensare, come restituisco lo stato? + // Da pensare, come restituisco lo stato? magari un pointer? virtual void sample() = 0; virtual void update_hypers( diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h new file mode 100644 index 000000000..5e1115147 --- /dev/null +++ b/src/hierarchies/priors/base_prior_model.h @@ -0,0 +1,81 @@ +#ifndef BAYESMIX_HIERARCHIES_BASE_PRIORMODEL_H_ +#define BAYESMIX_HIERARCHIES_BASE_PRIORMODEL_H_ + +#include + +#include +#include +#include +#include +#include + +#include "abstract_prior_model.h" +#include "algorithm_state.pb.h" +#include "hierarchy_id.pb.h" +#include "src/utils/rng.h" + +template +class BasePriorModel : public AbstractPriorModel { + public: + BasePriorModel() = default; + + ~BasePriorModel() = default; + + virtual std::shared_ptr clone() const override { + auto out = std::make_shared(static_cast(*this)); + return out; + } + + // sample method, che posso fare?? + + virtual google::protobuf::Message *get_mutable_prior() override { + if (prior == nullptr) { + create_empty_prior(); + } + return prior.get(); + } + + Hyperparams get_hypers() const { return *hypers; } + + void write_hypers_to_proto(google::protobuf::Message *out) const override; + + protected: + void check_prior_is_set() const { + if (prior == nullptr) { + throw std::invalid_argument("Hierarchy prior was not provided"); + } + } + + void create_empty_prior() { prior.reset(new Prior); } + + virtual std::shared_ptr + get_hypers_proto() const = 0; + + virtual void initialize_hypers() = 0; + + bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( + google::protobuf::Message *state_) const { + return google::protobuf::internal::down_cast< + bayesmix::AlgorithmState::HierarchyHypers *>(state_); + } + + const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( + const google::protobuf::Message &state_) const { + return google::protobuf::internal::down_cast< + const bayesmix::AlgorithmState::HierarchyHypers &>(state_); + } + + std::shared_ptr hypers; + std::shared_ptr prior; +}; + +template +void BasePriorModel::write_hypers_to_proto( + google::protobuf::Message *out) const { + std::shared_ptr hypers_ = + get_hypers_proto(); + auto *out_cast = downcast_hypers(out); + out_cast->CopyFrom(*hypers_.get()); +} + +#endif // BAYESMIX_HIERARCHIES_BASE_PRIORMODEL_H_ From f5008da6b4c9e48ce7dc694359d7b8c2a99ea6e6 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 12 Jan 2022 21:43:42 +0100 Subject: [PATCH 028/317] Improve API for PriorModel classes (ONGOING) --- src/hierarchies/priors/abstract_prior_model.h | 6 ++++++ src/hierarchies/priors/base_prior_model.h | 5 ----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 1e7b96581..4949eb564 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -32,6 +32,12 @@ class AbstractPriorModel { const google::protobuf::Message &state_) = 0; virtual void write_hypers_to_proto(google::protobuf::Message *out) const = 0; + + protected: + virtual std::shared_ptr + get_hypers_proto() const = 0; + + virtual void initialize_hypers() = 0; }; #endif diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 5e1115147..2513e2d2e 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -48,11 +48,6 @@ class BasePriorModel : public AbstractPriorModel { void create_empty_prior() { prior.reset(new Prior); } - virtual std::shared_ptr - get_hypers_proto() const = 0; - - virtual void initialize_hypers() = 0; - bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( google::protobuf::Message *state_) const { return google::protobuf::internal::down_cast< From 017dc0b2a3d3f8aee6c52582dd69c2608c202b91 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 13 Jan 2022 17:25:18 +0100 Subject: [PATCH 029/317] Add priors subdirectory to target bayesmix --- src/hierarchies/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index 567a6cd57..b354649d0 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -14,3 +14,4 @@ target_sources(bayesmix ) add_subdirectory(likelihoods) +add_subdirectory(priors) From b57da372b020eae559202e8820d9b405fa4d3003 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 13 Jan 2022 17:26:13 +0100 Subject: [PATCH 030/317] Add states.h to target bayesmix --- src/hierarchies/likelihoods/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index 216b0c849..3e5495073 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -2,6 +2,7 @@ target_sources(bayesmix PUBLIC abstract_likelihood.h base_likelihood.h + states.h uni_norm_likelihood.h uni_norm_likelihood.cc ) From 67453a915da5274be376bf98462b5e4ec92d5528 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 13 Jan 2022 17:27:36 +0100 Subject: [PATCH 031/317] Minor changes --- .../likelihoods/abstract_likelihood.h | 20 +++++++++---------- src/hierarchies/likelihoods/base_likelihood.h | 8 +++++--- src/hierarchies/likelihoods/states.h | 16 +++++++++++---- .../likelihoods/uni_norm_likelihood.cc | 4 ++-- .../likelihoods/uni_norm_likelihood.h | 5 ++--- 5 files changed, 31 insertions(+), 22 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index ed052f562..eef452965 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -54,12 +54,13 @@ class AbstractLikelihood { const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) = 0; - void update_sum_stats(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate, bool add) { + void update_summary_statistics(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, + bool add) { if (is_dependent()) { - return update_summary_statistics(datum, covariate, add); + return update_sum_stats(datum, covariate, add); } else { - return update_summary_statistics(datum, add); + return update_sum_stats(datum, add); } } @@ -88,8 +89,7 @@ class AbstractLikelihood { } } - virtual void update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) { + virtual void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) { if (is_dependent()) { throw std::runtime_error( "Cannot call this function from a dependent hierarchy"); @@ -98,9 +98,9 @@ class AbstractLikelihood { } } - virtual void update_summary_statistics(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate, - bool add) { + virtual void update_sum_stats(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, + bool add) { if (!is_dependent()) { throw std::runtime_error( "Cannot call this function from a non-dependent hierarchy"); @@ -110,4 +110,4 @@ class AbstractLikelihood { } }; -#endif +#endif // BAYESMIX_HIERARCHIES_ABSTRACT_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 11042edc5..9f35b82cd 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -86,7 +86,8 @@ void BaseLikelihood::add_datum( assert(cluster_data_idx.find(id) == cluster_data_idx.end()); card += 1; log_card = std::log(card); - static_cast(this)->update_sum_stats(datum, covariate, true); + static_cast(this)->update_summary_statistics(datum, covariate, + true); cluster_data_idx.insert(id); } @@ -94,7 +95,8 @@ template void BaseLikelihood::remove_datum( const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) { - static_cast(this)->update_sum_stats(datum, covariate, false); + static_cast(this)->update_summary_statistics(datum, covariate, + false); set_card(card - 1); auto it = cluster_data_idx.find(id); assert(it != cluster_data_idx.end()); @@ -137,4 +139,4 @@ Eigen::VectorXd BaseLikelihood::lpdf_grid( return lpdf; } -#endif +#endif // BAYESMIX_HIERARCHIES_BASE_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/states.h b/src/hierarchies/likelihoods/states.h index 73a368706..e9d7b7327 100644 --- a/src/hierarchies/likelihoods/states.h +++ b/src/hierarchies/likelihoods/states.h @@ -5,21 +5,29 @@ namespace State { -struct UniLS { +class Base { + protected: + Base() = default; + + public: + virtual ~Base() = default; +}; + +struct UniLS : public Base { double mean, var; }; -struct MultiLS { +struct MultiLS : public Base { Eigen::VectorXd mean; Eigen::MatrixXd prec, prec_chol; double prec_logdet; }; -struct UniLinReg { +struct UniLinReg : public Base { Eigen::VectorXd regression_coeffs; double var; }; } // namespace State -#endif +#endif // BAYESMIX_HIERARCHIES_LIKELIHOOD_STATES_H_ diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc index 9cdfb70aa..ac2539098 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -4,8 +4,8 @@ double UniNormLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); } -void UniNormLikelihood::update_summary_statistics( - const Eigen::RowVectorXd &datum, bool add) { +void UniNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, + bool add) { if (add) { data_sum += datum(0); data_sum_squares += datum(0) * datum(0); diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index b5c04aaba..5607d8b62 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -26,11 +26,10 @@ class UniNormLikelihood std::shared_ptr get_state_proto() const override; double compute_lpdf(const Eigen::RowVectorXd &datum) const override; - void update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) override; + void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; double data_sum = 0; double data_sum_squares = 0; }; -#endif +#endif // BAYESMIX_HIERARCHIES_UNI_NORM_LIKELIHOOD_H_ From fa9c27388e8cf4f11515ad702020d062ba38c0a1 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 13 Jan 2022 17:28:37 +0100 Subject: [PATCH 032/317] Add source files to target bayesmix --- src/hierarchies/priors/CMakeLists.txt | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 src/hierarchies/priors/CMakeLists.txt diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt new file mode 100644 index 000000000..8547a6acf --- /dev/null +++ b/src/hierarchies/priors/CMakeLists.txt @@ -0,0 +1,8 @@ +target_sources(bayesmix + PUBLIC + abstract_prior_model.h + base_prior_model.h + hyperparams.h + nig_prior_model.h + nig_prior_model.cc +) From 53996d591f51a02a94423c37d5fed7e031784e71 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 13 Jan 2022 17:29:52 +0100 Subject: [PATCH 033/317] Add NIGPriorModel class specification --- src/hierarchies/priors/nig_prior_model.cc | 196 ++++++++++++++++++++++ src/hierarchies/priors/nig_prior_model.h | 38 +++++ 2 files changed, 234 insertions(+) create mode 100644 src/hierarchies/priors/nig_prior_model.cc create mode 100644 src/hierarchies/priors/nig_prior_model.h diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc new file mode 100644 index 000000000..e7a1e5437 --- /dev/null +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -0,0 +1,196 @@ +#include "nig_prior_model.h" + +void NIGPriorModel::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers.mean = prior->fixed_values().mean(); + hypers.var_scaling = prior->fixed_values().var_scaling(); + hypers.shape = prior->fixed_values().shape(); + hypers.scale = prior->fixed_values().scale(); + // Check validity + if (hypers.var_scaling <= 0) { + throw std::invalid_argument("Variance-scaling parameter must be > 0"); + } + if (hypers.shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers.scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } else if (prior->has_normal_mean_prior()) { + // Set initial values + hypers.mean = prior->normal_mean_prior().mean_prior().mean(); + hypers.var_scaling = prior->normal_mean_prior().var_scaling(); + hypers.shape = prior->normal_mean_prior().shape(); + hypers.scale = prior->normal_mean_prior().scale(); + // Check validity + if (hypers.var_scaling <= 0) { + throw std::invalid_argument("Variance-scaling parameter must be > 0"); + } + if (hypers.shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers.scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } else if (prior->has_ngg_prior()) { + // Get hyperparameters: + // for mu0 + double mu00 = prior->ngg_prior().mean_prior().mean(); + double sigma00 = prior->ngg_prior().mean_prior().var(); + // for lambda0 + double alpha00 = prior->ngg_prior().var_scaling_prior().shape(); + double beta00 = prior->ngg_prior().var_scaling_prior().rate(); + // for beta0 + double a00 = prior->ngg_prior().scale_prior().shape(); + double b00 = prior->ngg_prior().scale_prior().rate(); + // for alpha0 + double alpha0 = prior->ngg_prior().shape(); + // Check validity + if (sigma00 <= 0) { + throw std::invalid_argument("Variance parameter must be > 0"); + } + if (alpha00 <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (beta00 <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + if (a00 <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (b00 <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + if (alpha0 <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + // Set initial values + hypers.mean = mu00; + hypers.var_scaling = alpha00 / beta00; + hypers.shape = alpha0; + hypers.scale = a00 / b00; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +double NIGPriorModel::lpdf() { + if (prior->has_fixed_values()) { + return 0; + } else if (prior->has_normal_mean_prior()) { + double mu = prior->normal_mean_prior().mean_prior().mean(); + double var = prior->normal_mean_prior().mean_prior().var(); + return stan::math::normal_lpdf(hypers.mean, mu, sqrt(var)); + } else if (prior->has_ngg_prior()) { + // Set variables + double mu, var, shape, rate; + double target = 0; + + // Gaussian distribution on the mean + mu = prior->ngg_prior().mean_prior().mean(); + var = prior->ngg_prior().mean_prior().var(); + target += stan::math::normal_lpdf(hypers.mean, mu, sqrt(var)); + + // Gamma distribution on var_scaling + shape = prior->ngg_prior().var_scaling_prior().shape(); + rate = prior->ngg_prior().var_scaling_prior().rate(); + target += stan::math::gamma_lpdf(hypers.var_scaling, shape, rate); + + // Gamma distribution on scale + shape = prior->ngg_prior().scale_prior().shape(); + rate = prior->ngg_prior().scale_prior().rate(); + target += stan::math::gamma_lpdf(hypers.var_scaling, shape, rate); + + return target; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void NIGPriorModel::update_hypers( + const std::vector &states) { + auto &rng = bayesmix::Rng::Instance().get(); + + if (prior->has_fixed_values()) { + return; + } else if (prior->has_normal_mean_prior()) { + // Get hyperparameters + double mu00 = prior->normal_mean_prior().mean_prior().mean(); + double sig200 = prior->normal_mean_prior().mean_prior().var(); + double lambda0 = prior->normal_mean_prior().var_scaling(); + // Compute posterior hyperparameters + double prec = 0.0; + double num = 0.0; + for (auto &st : states) { + double mean = st.uni_ls_state().mean(); + double var = st.uni_ls_state().var(); + prec += 1 / var; + num += mean / var; + } + prec = 1 / sig200 + lambda0 * prec; + num = mu00 / sig200 + lambda0 * num; + double mu_n = num / prec; + double sig2_n = 1 / prec; + // Update hyperparameters with posterior random sampling + hypers.mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); + } else if (prior->has_ngg_prior()) { + // Get hyperparameters: + // for mu0 + double mu00 = prior->ngg_prior().mean_prior().mean(); + double sig200 = prior->ngg_prior().mean_prior().var(); + // for lambda0 + double alpha00 = prior->ngg_prior().var_scaling_prior().shape(); + double beta00 = prior->ngg_prior().var_scaling_prior().rate(); + // for tau0 + double a00 = prior->ngg_prior().scale_prior().shape(); + double b00 = prior->ngg_prior().scale_prior().rate(); + // Compute posterior hyperparameters + double b_n = 0.0; + double num = 0.0; + double beta_n = 0.0; + for (auto &st : states) { + double mean = st.uni_ls_state().mean(); + double var = st.uni_ls_state().var(); + b_n += 1 / var; + num += mean / var; + beta_n += (hypers.mean - mean) * (hypers.mean - mean) / var; + } + double var = hypers.var_scaling * b_n + 1 / sig200; + b_n += b00; + num = hypers.var_scaling * num + mu00 / sig200; + beta_n = beta00 + 0.5 * beta_n; + double sig_n = 1 / var; + double mu_n = num / var; + double alpha_n = alpha00 + 0.5 * states.size(); + double a_n = a00 + states.size() * hypers.shape; + // Update hyperparameters with posterior random Gibbs sampling + hypers.mean = stan::math::normal_rng(mu_n, sig_n, rng); + hypers.var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); + hypers.scale = stan::math::gamma_rng(a_n, b_n, rng); + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void NIGPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).nnig_state(); + hypers.mean = hyperscast.mean(); + hypers.var_scaling = hyperscast.var_scaling(); + hypers.scale = hyperscast.scale(); + hypers.shape = hyperscast.shape(); +} + +std::shared_ptr +NIGPriorModel::get_hypers_proto() const { + bayesmix::NIGDistribution hypers_; + hypers_.set_mean(hypers.mean); + hypers_.set_var_scaling(hypers.var_scaling); + hypers_.set_shape(hypers.shape); + hypers_.set_scale(hypers.scale); + + auto out = std::make_shared(); + out->mutable_nnig_state()->CopyFrom(hypers_); + return out; +} diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h new file mode 100644 index 000000000..04a356bc8 --- /dev/null +++ b/src/hierarchies/priors/nig_prior_model.h @@ -0,0 +1,38 @@ +#ifndef BAYESMIX_HIERARCHIES_NIG_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_NIG_PRIOR_MODEL_H_ + +// #include + +#include +#include +#include +#include + +// #include "algorithm_state.pb.h" +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +class NIGPriorModel : public BasePriorModel { + public: + NIGPriorModel() = default; + ~NIGPriorModel() = default; + + void initialize_hypers() override; + + double lpdf() override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + protected: + std::shared_ptr get_hypers_proto() + const override; +}; + +#endif // BAYESMIX_HIERARCHIES_NIG_PRIOR_MODEL_H_ From 2efa726039212389f850b204e1274879fa9b3986 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 13 Jan 2022 17:30:23 +0100 Subject: [PATCH 034/317] Minor changes --- src/hierarchies/priors/abstract_prior_model.h | 10 +++++----- src/hierarchies/priors/base_prior_model.h | 12 +++++------- src/hierarchies/priors/hyperparams.h | 2 +- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 4949eb564..1d44b6a9c 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -14,23 +14,23 @@ class AbstractPriorModel { public: virtual ~AbstractPriorModel() = default; + // IMPLEMENTED in BasePriorModel virtual std::shared_ptr clone() const = 0; virtual double lpdf() = 0; - // Da pensare, come restituisco lo stato? magari un pointer? - virtual void sample() = 0; + // Da pensare, come restituisco lo stato? magari un pointer? Oppure delego + // all'updater?? virtual void sample() = 0; virtual void update_hypers( const std::vector &states) = 0; - virtual void initialize_hypers() = 0; - virtual google::protobuf::Message *get_mutable_prior() = 0; virtual void set_hypers_from_proto( const google::protobuf::Message &state_) = 0; + // IMPLEMENTED in BasePriorModel virtual void write_hypers_to_proto(google::protobuf::Message *out) const = 0; protected: @@ -40,4 +40,4 @@ class AbstractPriorModel { virtual void initialize_hypers() = 0; }; -#endif +#endif // BAYESMIX_HIERARCHIES_ABSTRACT_PRIORMODEL_H_ diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 2513e2d2e..cbc6b3379 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -14,7 +14,7 @@ #include "hierarchy_id.pb.h" #include "src/utils/rng.h" -template +template class BasePriorModel : public AbstractPriorModel { public: BasePriorModel() = default; @@ -26,8 +26,6 @@ class BasePriorModel : public AbstractPriorModel { return out; } - // sample method, che posso fare?? - virtual google::protobuf::Message *get_mutable_prior() override { if (prior == nullptr) { create_empty_prior(); @@ -35,7 +33,7 @@ class BasePriorModel : public AbstractPriorModel { return prior.get(); } - Hyperparams get_hypers() const { return *hypers; } + HyperParams get_hypers() const { return *hypers; } void write_hypers_to_proto(google::protobuf::Message *out) const override; @@ -60,12 +58,12 @@ class BasePriorModel : public AbstractPriorModel { const bayesmix::AlgorithmState::HierarchyHypers &>(state_); } - std::shared_ptr hypers; + HyperParams hypers; // std::shared_ptr hypers; std::shared_ptr prior; }; -template -void BasePriorModel::write_hypers_to_proto( +template +void BasePriorModel::write_hypers_to_proto( google::protobuf::Message *out) const { std::shared_ptr hypers_ = get_hypers_proto(); diff --git a/src/hierarchies/priors/hyperparams.h b/src/hierarchies/priors/hyperparams.h index aaddb6c4a..b338aa1fe 100644 --- a/src/hierarchies/priors/hyperparams.h +++ b/src/hierarchies/priors/hyperparams.h @@ -27,4 +27,4 @@ struct MNIG { } // namespace Hyperparams -#endif +#endif // BAYESMIX_HIERARCHIES_PRIORMODEL_HYPERPARAMS_H_ From f2dea8ecad5fadfbfa35642ec0d1676686334648 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 13 Jan 2022 17:30:46 +0100 Subject: [PATCH 035/317] Changed test names --- test/likelihoods.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/likelihoods.cc b/test/likelihoods.cc index 0058538f1..d5af2a024 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -9,7 +9,7 @@ #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" #include "src/utils/rng.h" -TEST(uni_norm_likelihood, state_setget) { +TEST(uni_norm_likelihood, set_get_state) { // Instance auto like = std::make_shared(); @@ -31,7 +31,7 @@ TEST(uni_norm_likelihood, state_setget) { ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); } -TEST(uni_norm_likelihood, data_addremove) { +TEST(uni_norm_likelihood, add_remove_data) { // Instance auto like = std::make_shared(); From 094037628076d2ffaf8d170a8a3456337997d623 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 13 Jan 2022 17:31:10 +0100 Subject: [PATCH 036/317] Add test for Prior Models (ONGOING) --- test/CMakeLists.txt | 1 + test/prior_models.cc | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 test/prior_models.cc diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 62d34fde3..533330ddd 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -29,6 +29,7 @@ add_executable(test_bayesmix $ # rng.cc # logit_sb.cc likelihoods.cc + prior_models.cc ) target_include_directories(test_bayesmix PUBLIC ${INCLUDE_PATHS}) diff --git a/test/prior_models.cc b/test/prior_models.cc new file mode 100644 index 000000000..dec618160 --- /dev/null +++ b/test/prior_models.cc @@ -0,0 +1,35 @@ +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "hierarchy_prior.pb.h" +// #include "ls_state.pb.h" +#include "src/hierarchies/priors/nig_prior_model.h" +// #include "src/utils/rng.h" + +TEST(nig_prior_model, set_get_hypers) { + // Instance + auto prior = std::make_shared(); + + // Prepare buffers + bayesmix::NIGDistribution hypers_; + bayesmix::AlgorithmState::HierarchyHypers set_state_; + bayesmix::AlgorithmState::HierarchyHypers got_state_; + + // Prepare hypers + hypers_.set_mean(5.0); + hypers_.set_var_scaling(0.1); + hypers_.set_shape(4.0); + hypers_.set_scale(3.0); + set_state_.mutable_nnig_state()->CopyFrom(hypers_); + + // Set and get hypers + prior->set_hypers_from_proto(set_state_); + prior->write_hypers_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} From 07fb6b293567e233149559b44ea1d8897174b34f Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 13 Jan 2022 23:33:03 +0100 Subject: [PATCH 037/317] Bug fixes --- src/hierarchies/priors/base_prior_model.h | 4 +- src/hierarchies/priors/nig_prior_model.cc | 76 +++++++++++------------ 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index cbc6b3379..62880ae64 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -58,8 +58,8 @@ class BasePriorModel : public AbstractPriorModel { const bayesmix::AlgorithmState::HierarchyHypers &>(state_); } - HyperParams hypers; // std::shared_ptr hypers; - std::shared_ptr prior; + std::shared_ptr hypers = std::make_shared(); + std::shared_ptr prior = std::make_shared(); }; template diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index e7a1e5437..875d247dd 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -3,34 +3,34 @@ void NIGPriorModel::initialize_hypers() { if (prior->has_fixed_values()) { // Set values - hypers.mean = prior->fixed_values().mean(); - hypers.var_scaling = prior->fixed_values().var_scaling(); - hypers.shape = prior->fixed_values().shape(); - hypers.scale = prior->fixed_values().scale(); + hypers->mean = prior->fixed_values().mean(); + hypers->var_scaling = prior->fixed_values().var_scaling(); + hypers->shape = prior->fixed_values().shape(); + hypers->scale = prior->fixed_values().scale(); // Check validity - if (hypers.var_scaling <= 0) { + if (hypers->var_scaling <= 0) { throw std::invalid_argument("Variance-scaling parameter must be > 0"); } - if (hypers.shape <= 0) { + if (hypers->shape <= 0) { throw std::invalid_argument("Shape parameter must be > 0"); } - if (hypers.scale <= 0) { + if (hypers->scale <= 0) { throw std::invalid_argument("scale parameter must be > 0"); } } else if (prior->has_normal_mean_prior()) { // Set initial values - hypers.mean = prior->normal_mean_prior().mean_prior().mean(); - hypers.var_scaling = prior->normal_mean_prior().var_scaling(); - hypers.shape = prior->normal_mean_prior().shape(); - hypers.scale = prior->normal_mean_prior().scale(); + hypers->mean = prior->normal_mean_prior().mean_prior().mean(); + hypers->var_scaling = prior->normal_mean_prior().var_scaling(); + hypers->shape = prior->normal_mean_prior().shape(); + hypers->scale = prior->normal_mean_prior().scale(); // Check validity - if (hypers.var_scaling <= 0) { + if (hypers->var_scaling <= 0) { throw std::invalid_argument("Variance-scaling parameter must be > 0"); } - if (hypers.shape <= 0) { + if (hypers->shape <= 0) { throw std::invalid_argument("Shape parameter must be > 0"); } - if (hypers.scale <= 0) { + if (hypers->scale <= 0) { throw std::invalid_argument("scale parameter must be > 0"); } } else if (prior->has_ngg_prior()) { @@ -66,10 +66,10 @@ void NIGPriorModel::initialize_hypers() { throw std::invalid_argument("Shape parameter must be > 0"); } // Set initial values - hypers.mean = mu00; - hypers.var_scaling = alpha00 / beta00; - hypers.shape = alpha0; - hypers.scale = a00 / b00; + hypers->mean = mu00; + hypers->var_scaling = alpha00 / beta00; + hypers->shape = alpha0; + hypers->scale = a00 / b00; } else { throw std::invalid_argument("Unrecognized hierarchy prior"); } @@ -81,7 +81,7 @@ double NIGPriorModel::lpdf() { } else if (prior->has_normal_mean_prior()) { double mu = prior->normal_mean_prior().mean_prior().mean(); double var = prior->normal_mean_prior().mean_prior().var(); - return stan::math::normal_lpdf(hypers.mean, mu, sqrt(var)); + return stan::math::normal_lpdf(hypers->mean, mu, sqrt(var)); } else if (prior->has_ngg_prior()) { // Set variables double mu, var, shape, rate; @@ -90,17 +90,17 @@ double NIGPriorModel::lpdf() { // Gaussian distribution on the mean mu = prior->ngg_prior().mean_prior().mean(); var = prior->ngg_prior().mean_prior().var(); - target += stan::math::normal_lpdf(hypers.mean, mu, sqrt(var)); + target += stan::math::normal_lpdf(hypers->mean, mu, sqrt(var)); // Gamma distribution on var_scaling shape = prior->ngg_prior().var_scaling_prior().shape(); rate = prior->ngg_prior().var_scaling_prior().rate(); - target += stan::math::gamma_lpdf(hypers.var_scaling, shape, rate); + target += stan::math::gamma_lpdf(hypers->var_scaling, shape, rate); // Gamma distribution on scale shape = prior->ngg_prior().scale_prior().shape(); rate = prior->ngg_prior().scale_prior().rate(); - target += stan::math::gamma_lpdf(hypers.var_scaling, shape, rate); + target += stan::math::gamma_lpdf(hypers->var_scaling, shape, rate); return target; } else { @@ -133,7 +133,7 @@ void NIGPriorModel::update_hypers( double mu_n = num / prec; double sig2_n = 1 / prec; // Update hyperparameters with posterior random sampling - hypers.mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); + hypers->mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); } else if (prior->has_ngg_prior()) { // Get hyperparameters: // for mu0 @@ -154,20 +154,20 @@ void NIGPriorModel::update_hypers( double var = st.uni_ls_state().var(); b_n += 1 / var; num += mean / var; - beta_n += (hypers.mean - mean) * (hypers.mean - mean) / var; + beta_n += (hypers->mean - mean) * (hypers->mean - mean) / var; } - double var = hypers.var_scaling * b_n + 1 / sig200; + double var = hypers->var_scaling * b_n + 1 / sig200; b_n += b00; - num = hypers.var_scaling * num + mu00 / sig200; + num = hypers->var_scaling * num + mu00 / sig200; beta_n = beta00 + 0.5 * beta_n; double sig_n = 1 / var; double mu_n = num / var; double alpha_n = alpha00 + 0.5 * states.size(); - double a_n = a00 + states.size() * hypers.shape; + double a_n = a00 + states.size() * hypers->shape; // Update hyperparameters with posterior random Gibbs sampling - hypers.mean = stan::math::normal_rng(mu_n, sig_n, rng); - hypers.var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); - hypers.scale = stan::math::gamma_rng(a_n, b_n, rng); + hypers->mean = stan::math::normal_rng(mu_n, sig_n, rng); + hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); + hypers->scale = stan::math::gamma_rng(a_n, b_n, rng); } else { throw std::invalid_argument("Unrecognized hierarchy prior"); } @@ -176,19 +176,19 @@ void NIGPriorModel::update_hypers( void NIGPriorModel::set_hypers_from_proto( const google::protobuf::Message &hypers_) { auto &hyperscast = downcast_hypers(hypers_).nnig_state(); - hypers.mean = hyperscast.mean(); - hypers.var_scaling = hyperscast.var_scaling(); - hypers.scale = hyperscast.scale(); - hypers.shape = hyperscast.shape(); + hypers->mean = hyperscast.mean(); + hypers->var_scaling = hyperscast.var_scaling(); + hypers->scale = hyperscast.scale(); + hypers->shape = hyperscast.shape(); } std::shared_ptr NIGPriorModel::get_hypers_proto() const { bayesmix::NIGDistribution hypers_; - hypers_.set_mean(hypers.mean); - hypers_.set_var_scaling(hypers.var_scaling); - hypers_.set_shape(hypers.shape); - hypers_.set_scale(hypers.scale); + hypers_.set_mean(hypers->mean); + hypers_.set_var_scaling(hypers->var_scaling); + hypers_.set_shape(hypers->shape); + hypers_.set_scale(hypers->scale); auto out = std::make_shared(); out->mutable_nnig_state()->CopyFrom(hypers_); From e3fa06d1ed685871b8db43970d6b6793e4eed200 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 14 Jan 2022 14:56:13 +0100 Subject: [PATCH 038/317] Add unit tests for NIGPriorModel --- test/prior_models.cc | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/prior_models.cc b/test/prior_models.cc index dec618160..67d86f81f 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -33,3 +33,38 @@ TEST(nig_prior_model, set_get_hypers) { // Check if they coincides ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); } + +// TODO: test for the other priors available +TEST(nig_prior_model, fixed_values_prior) { + bayesmix::NNIGPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + prior.mutable_fixed_values()->set_mean(5.0); + prior.mutable_fixed_values()->set_var_scaling(0.1); + prior.mutable_fixed_values()->set_shape(2.0); + prior.mutable_fixed_values()->set_scale(2.0); + + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + std::vector> prior_models; + std::vector states; + + // Check equality before update + prior_models.push_back(prior_model); + for (size_t i = 1; i < 4; i++) { + prior_models.push_back(prior_model->clone()); + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnig_state().DebugString()); + } + + // Check equality after update + prior_models[0]->update_hypers(states); + prior_models[0]->write_hypers_to_proto(&prior_out); + for (size_t i = 1; i < 4; i++) { + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnig_state().DebugString()); + } +} From 90f00ae44fa84009141214df4152c0bae58b6f0a Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 14 Jan 2022 14:56:32 +0100 Subject: [PATCH 039/317] Improve PriorModels API --- src/hierarchies/priors/base_prior_model.h | 5 +++++ src/hierarchies/priors/nig_prior_model.h | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 62880ae64..d9a4eaa51 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -37,6 +37,11 @@ class BasePriorModel : public AbstractPriorModel { void write_hypers_to_proto(google::protobuf::Message *out) const override; + void initialize() { + check_prior_is_set(); + initialize_hypers(); + } + protected: void check_prior_is_set() const { if (prior == nullptr) { diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index 04a356bc8..ca359980b 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -20,8 +20,6 @@ class NIGPriorModel : public BasePriorModel @@ -33,6 +31,8 @@ class NIGPriorModel : public BasePriorModel get_hypers_proto() const override; + + void initialize_hypers() override; }; #endif // BAYESMIX_HIERARCHIES_NIG_PRIOR_MODEL_H_ From 1f7d61a92d7497ef68b9458aa29cdcd04389cfbf Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 14 Jan 2022 18:31:57 +0100 Subject: [PATCH 040/317] Minor code changes --- src/hierarchies/priors/base_prior_model.h | 56 +++++++++++++++-------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index d9a4eaa51..a3a46728b 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -21,33 +21,18 @@ class BasePriorModel : public AbstractPriorModel { ~BasePriorModel() = default; - virtual std::shared_ptr clone() const override { - auto out = std::make_shared(static_cast(*this)); - return out; - } + virtual std::shared_ptr clone() const override; - virtual google::protobuf::Message *get_mutable_prior() override { - if (prior == nullptr) { - create_empty_prior(); - } - return prior.get(); - } + virtual google::protobuf::Message *get_mutable_prior() override; HyperParams get_hypers() const { return *hypers; } void write_hypers_to_proto(google::protobuf::Message *out) const override; - void initialize() { - check_prior_is_set(); - initialize_hypers(); - } + void initialize(); protected: - void check_prior_is_set() const { - if (prior == nullptr) { - throw std::invalid_argument("Hierarchy prior was not provided"); - } - } + void check_prior_is_set() const; void create_empty_prior() { prior.reset(new Prior); } @@ -64,9 +49,27 @@ class BasePriorModel : public AbstractPriorModel { } std::shared_ptr hypers = std::make_shared(); - std::shared_ptr prior = std::make_shared(); + std::shared_ptr prior; }; +// Methods Definitions + +template +std::shared_ptr +BasePriorModel::clone() const { + auto out = std::make_shared(static_cast(*this)); + return out; +} + +template +google::protobuf::Message * +BasePriorModel::get_mutable_prior() { + if (prior == nullptr) { + create_empty_prior(); + } + return prior.get(); +} + template void BasePriorModel::write_hypers_to_proto( google::protobuf::Message *out) const { @@ -76,4 +79,17 @@ void BasePriorModel::write_hypers_to_proto( out_cast->CopyFrom(*hypers_.get()); } +template +void BasePriorModel::initialize() { + check_prior_is_set(); + initialize_hypers(); +} + +template +void BasePriorModel::check_prior_is_set() const { + if (prior == nullptr) { + throw std::invalid_argument("Hierarchy prior was not provided"); + } +} + #endif // BAYESMIX_HIERARCHIES_BASE_PRIORMODEL_H_ From cb14d176419934c9beac9a2bcdbbb62cf5fbe99d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 14 Jan 2022 18:32:42 +0100 Subject: [PATCH 041/317] Add NIGPriorModel test with Normal Mean Prior --- test/prior_models.cc | 45 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/test/prior_models.cc b/test/prior_models.cc index 67d86f81f..0de9d9a8a 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -34,22 +34,24 @@ TEST(nig_prior_model, set_get_hypers) { ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); } -// TODO: test for the other priors available TEST(nig_prior_model, fixed_values_prior) { + // Prepare buffers bayesmix::NNIGPrior prior; bayesmix::AlgorithmState::HierarchyHypers prior_out; + std::vector> prior_models; + std::vector states; + + // Set fixed value prior prior.mutable_fixed_values()->set_mean(5.0); prior.mutable_fixed_values()->set_var_scaling(0.1); prior.mutable_fixed_values()->set_shape(2.0); prior.mutable_fixed_values()->set_scale(2.0); + // Initialize prior model auto prior_model = std::make_shared(); prior_model->get_mutable_prior()->CopyFrom(prior); prior_model->initialize(); - std::vector> prior_models; - std::vector states; - // Check equality before update prior_models.push_back(prior_model); for (size_t i = 1; i < 4; i++) { @@ -68,3 +70,38 @@ TEST(nig_prior_model, fixed_values_prior) { prior_out.nnig_state().DebugString()); } } + +TEST(nig_prior_model, normal_mean_prior) { + // Prepare buffers + bayesmix::NNIGPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + + // Set Normal prior on the mean + double mu00 = 0.5; + prior.mutable_normal_mean_prior()->mutable_mean_prior()->set_mean(mu00); + prior.mutable_normal_mean_prior()->mutable_mean_prior()->set_var(1.02); + prior.mutable_normal_mean_prior()->set_var_scaling(0.1); + prior.mutable_normal_mean_prior()->set_shape(2.0); + prior.mutable_normal_mean_prior()->set_scale(2.0); + + // Prepare some fictional states + std::vector states(4); + for (int i = 0; i < states.size(); i++) { + double mean = 9.0 + i; + states[i].mutable_uni_ls_state()->set_mean(mean); + states[i].mutable_uni_ls_state()->set_var(1.0); + } + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Update hypers in light of current states + prior_model->update_hypers(states); + prior_model->write_hypers_to_proto(&prior_out); + double mean_out = prior_out.nnig_state().mean(); + + // Check + ASSERT_GT(mean_out, mu00); +} From 58c1c9792da32880e0fd3477d82134c4845e2262 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 19 Jan 2022 22:24:37 +0100 Subject: [PATCH 042/317] Moved old hierarchy pattern to .old --- src/hierarchies/.old/base_hierarchy.h | 321 +++++++++++++++ src/hierarchies/.old/conjugate_hierarchy.h | 206 ++++++++++ src/hierarchies/.old/lin_reg_uni_hierarchy.cc | 166 ++++++++ src/hierarchies/.old/lin_reg_uni_hierarchy.h | 144 +++++++ src/hierarchies/.old/nnig_hierarchy.cc | 267 +++++++++++++ src/hierarchies/.old/nnig_hierarchy.h | 122 ++++++ src/hierarchies/.old/nnw_hierarchy.cc | 373 ++++++++++++++++++ src/hierarchies/.old/nnw_hierarchy.h | 168 ++++++++ src/hierarchies/.old/nnxig_hierarchy.cc | 152 +++++++ src/hierarchies/.old/nnxig_hierarchy.h | 120 ++++++ 10 files changed, 2039 insertions(+) create mode 100644 src/hierarchies/.old/base_hierarchy.h create mode 100644 src/hierarchies/.old/conjugate_hierarchy.h create mode 100644 src/hierarchies/.old/lin_reg_uni_hierarchy.cc create mode 100644 src/hierarchies/.old/lin_reg_uni_hierarchy.h create mode 100644 src/hierarchies/.old/nnig_hierarchy.cc create mode 100644 src/hierarchies/.old/nnig_hierarchy.h create mode 100644 src/hierarchies/.old/nnw_hierarchy.cc create mode 100644 src/hierarchies/.old/nnw_hierarchy.h create mode 100644 src/hierarchies/.old/nnxig_hierarchy.cc create mode 100644 src/hierarchies/.old/nnxig_hierarchy.h diff --git a/src/hierarchies/.old/base_hierarchy.h b/src/hierarchies/.old/base_hierarchy.h new file mode 100644 index 000000000..e97c92080 --- /dev/null +++ b/src/hierarchies/.old/base_hierarchy.h @@ -0,0 +1,321 @@ +#ifndef BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ + +#include + +#include +#include +#include +#include +#include + +#include "abstract_hierarchy.h" +#include "algorithm_state.pb.h" +#include "hierarchy_id.pb.h" +#include "src/utils/rng.h" + +//! Base template class for a hierarchy object. + +//! This class is a templatized version of, and derived from, the +//! `AbstractHierarchy` class, and the second stage of the curiously recurring +//! template pattern for `Hierarchy` objects (please see the docs of the parent +//! class for further information). It includes class members and some more +//! functions which could not be implemented in the non-templatized abstract +//! class. +//! See, for instance, `ConjugateHierarchy` and `NNIGHierarchy` to better +//! understand the CRTP patterns. + +//! @tparam Derived Name of the implemented derived class +//! @tparam State Class name of the container for state values +//! @tparam Hyperparams Class name of the container for hyperprior parameters +//! @tparam Prior Class name of the container for prior parameters + +template +class BaseHierarchy : public AbstractHierarchy { + public: + BaseHierarchy() = default; + ~BaseHierarchy() = default; + + //! Returns an independent, data-less copy of this object + virtual std::shared_ptr clone() const override { + auto out = std::make_shared(static_cast(*this)); + out->clear_data(); + out->clear_summary_statistics(); + return out; + } + + //! Evaluates the log-likelihood of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf + virtual Eigen::VectorXd like_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, + 0)) const override; + + //! Generates new state values from the centering prior distribution + void sample_prior() override { + state = static_cast(this)->draw(*hypers); + } + + //! Overloaded version of sample_full_cond(bool), mainly used for debugging + virtual void sample_full_cond( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override; + + //! Returns the current cardinality of the cluster + int get_card() const override { return card; } + + //! Returns the logarithm of the current cardinality of the cluster + double get_log_card() const override { return log_card; } + + //! Returns the indexes of data points belonging to this cluster + std::set get_data_idx() const override { return cluster_data_idx; } + + //! Returns a pointer to the Protobuf message of the prior of this cluster + virtual google::protobuf::Message *get_mutable_prior() override { + if (prior == nullptr) { + create_empty_prior(); + } + return prior.get(); + } + + //! Writes current state to a Protobuf message by pointer + void write_state_to_proto(google::protobuf::Message *out) const override; + + //! Writes current values of the hyperparameters to a Protobuf message by + //! pointer + void write_hypers_to_proto(google::protobuf::Message *out) const override; + + //! Returns the struct of the current state + State get_state() const { return state; } + + //! Returns the struct of the current prior hyperparameters + Hyperparams get_hypers() const { return *hypers; } + + //! Returns the struct of the current posterior hyperparameters + Hyperparams get_posterior_hypers() const { return posterior_hypers; } + + //! Adds a datum and its index to the hierarchy + void add_datum( + const int id, const Eigen::RowVectorXd &datum, + const bool update_params = false, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + + //! Removes a datum and its index from the hierarchy + void remove_datum( + const int id, const Eigen::RowVectorXd &datum, + const bool update_params = false, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + + //! Main function that initializes members to appropriate values + void initialize() override { + hypers = std::make_shared(); + check_prior_is_set(); + initialize_hypers(); + initialize_state(); + posterior_hypers = *hypers; + clear_data(); + clear_summary_statistics(); + } + + protected: + //! Raises an error if the prior pointer is not initialized + void check_prior_is_set() const { + if (prior == nullptr) { + throw std::invalid_argument("Hierarchy prior was not provided"); + } + } + + //! Re-initializes the prior of the hierarchy to a newly created object + void create_empty_prior() { prior.reset(new Prior); } + + //! Sets the cardinality of the cluster + void set_card(const int card_) { + card = card_; + log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); + } + + //! Writes current state to a Protobuf message and return a shared_ptr + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::ClusterState message by adding the appropriate type + virtual std::shared_ptr + get_state_proto() const = 0; + + //! Initializes state parameters to appropriate values + virtual void initialize_state() = 0; + + //! Writes current value of hyperparameters to a Protobuf message and + //! return a shared_ptr. + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::HierarchyHypers message by adding the appropriate type + virtual std::shared_ptr + get_hypers_proto() const = 0; + + //! Initializes hierarchy hyperparameters to appropriate values + virtual void initialize_hypers() = 0; + + //! Resets cardinality and indexes of data in this cluster + void clear_data() { + set_card(0); + cluster_data_idx = std::set(); + } + + virtual void clear_summary_statistics() = 0; + + //! Down-casts the given generic proto message to a ClusterState proto + bayesmix::AlgorithmState::ClusterState *downcast_state( + google::protobuf::Message *state_) const { + return google::protobuf::internal::down_cast< + bayesmix::AlgorithmState::ClusterState *>(state_); + } + + //! Down-casts the given generic proto message to a ClusterState proto + const bayesmix::AlgorithmState::ClusterState &downcast_state( + const google::protobuf::Message &state_) const { + return google::protobuf::internal::down_cast< + const bayesmix::AlgorithmState::ClusterState &>(state_); + } + + //! Down-casts the given generic proto message to a HierarchyHypers proto + bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( + google::protobuf::Message *state_) const { + return google::protobuf::internal::down_cast< + bayesmix::AlgorithmState::HierarchyHypers *>(state_); + } + + //! Down-casts the given generic proto message to a HierarchyHypers proto + const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( + const google::protobuf::Message &state_) const { + return google::protobuf::internal::down_cast< + const bayesmix::AlgorithmState::HierarchyHypers &>(state_); + } + + //! Container for state values + State state; + + //! Container for prior hyperparameters values + std::shared_ptr hypers; + + //! Container for posterior hyperparameters values + Hyperparams posterior_hypers; + + //! Pointer to a Protobuf prior object for this class + std::shared_ptr prior; + + //! Set of indexes of data points belonging to this cluster + std::set cluster_data_idx; + + //! Current cardinality of this cluster + int card = 0; + + //! Logarithm of current cardinality of this cluster + double log_card = stan::math::NEGATIVE_INFTY; +}; + +template +void BaseHierarchy::add_datum( + const int id, const Eigen::RowVectorXd &datum, + const bool update_params /*= false*/, + const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) { + assert(cluster_data_idx.find(id) == cluster_data_idx.end()); + card += 1; + log_card = std::log(card); + static_cast(this)->update_ss(datum, covariate, true); + cluster_data_idx.insert(id); + if (update_params) { + static_cast(this)->save_posterior_hypers(); + } +} + +template +void BaseHierarchy::remove_datum( + const int id, const Eigen::RowVectorXd &datum, + const bool update_params /*= false*/, + const Eigen::RowVectorXd &covariate /* = Eigen::RowVectorXd(0)*/) { + static_cast(this)->update_ss(datum, covariate, false); + set_card(card - 1); + auto it = cluster_data_idx.find(id); + assert(it != cluster_data_idx.end()); + cluster_data_idx.erase(it); + if (update_params) { + static_cast(this)->save_posterior_hypers(); + } +} + +template +void BaseHierarchy::write_state_to_proto( + google::protobuf::Message *out) const { + std::shared_ptr state_ = + get_state_proto(); + auto *out_cast = downcast_state(out); + out_cast->CopyFrom(*state_.get()); + out_cast->set_cardinality(card); +} + +template +void BaseHierarchy::write_hypers_to_proto( + google::protobuf::Message *out) const { + std::shared_ptr hypers_ = + get_hypers_proto(); + auto *out_cast = downcast_hypers(out); + out_cast->CopyFrom(*hypers_.get()); +} + +template +Eigen::VectorXd +BaseHierarchy::like_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { + Eigen::VectorXd lpdf(data.rows()); + if (covariates.cols() == 0) { + // Pass null value as covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->get_like_lpdf( + data.row(i), Eigen::RowVectorXd(0)); + } + } else if (covariates.rows() == 1) { + // Use unique covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->get_like_lpdf( + data.row(i), covariates.row(0)); + } + } else { + // Use different covariates + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->get_like_lpdf( + data.row(i), covariates.row(i)); + } + } + return lpdf; +} + +template +void BaseHierarchy::sample_full_cond( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) { + clear_data(); + clear_summary_statistics(); + if (covariates.cols() == 0) { + // Pass null value as covariate + for (int i = 0; i < data.rows(); i++) { + static_cast(this)->add_datum(i, data.row(i), false, + Eigen::RowVectorXd(0)); + } + } else if (covariates.rows() == 1) { + // Use unique covariate + for (int i = 0; i < data.rows(); i++) { + static_cast(this)->add_datum(i, data.row(i), false, + covariates.row(0)); + } + } else { + // Use different covariates + for (int i = 0; i < data.rows(); i++) { + static_cast(this)->add_datum(i, data.row(i), false, + covariates.row(i)); + } + } + static_cast(this)->sample_full_cond(true); +} + +#endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ diff --git a/src/hierarchies/.old/conjugate_hierarchy.h b/src/hierarchies/.old/conjugate_hierarchy.h new file mode 100644 index 000000000..3a7350c98 --- /dev/null +++ b/src/hierarchies/.old/conjugate_hierarchy.h @@ -0,0 +1,206 @@ +#ifndef BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_ + +#include "base_hierarchy.h" + +//! Template base class for conjugate hierarchy objects. + +//! This class acts as the base class for conjugate models, i.e. ones for which +//! both the prior and posterior distribution have the same form +//! (non-conjugate hierarchies should instead inherit directly from +//! `BaseHierarchy`). This also means that the marginal distribution for the +//! data is available in closed form. For this reason, each class deriving from +//! this one must have a free method with one of the following signatures, +//! based on whether it depends on covariates or not: +//! double marg_lpdf( +//! const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, +//! const Eigen::RowVectorXd &covariate) const; +//! or +//! double marg_lpdf( +//! const Hyperparams ¶ms, const Eigen::RowVectorXd &datum) const; +//! This returns the evaluation of the marginal distribution on the given data +//! point (and covariate, if any), conditioned on the provided `Hyperparams` +//! object. The latter may contain either prior or posterior values for +//! hyperparameters, depending on where this function is called within the +//! library. +//! For more information, please refer to parent classes `AbstractHierarchy` +//! and `BaseHierarchy`. + +template +class ConjugateHierarchy + : public BaseHierarchy { + public: + using BaseHierarchy::hypers; + using BaseHierarchy::posterior_hypers; + using BaseHierarchy::state; + + ConjugateHierarchy() = default; + ~ConjugateHierarchy() = default; + + //! Public wrapper for `marg_lpdf()` methods + virtual double get_marg_lpdf( + const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const; + + //! Evaluates the log-prior predictive distribution of data in a single point + //! @param datum Point which is to be evaluated + //! @param covariate (Optional) covariate vector associated to datum + //! @return The evaluation of the lpdf + double prior_pred_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = + Eigen::RowVectorXd(0)) const override { + return get_marg_lpdf(*hypers, datum, covariate); + } + + //! Evaluates the log-conditional predictive distr. of data in a single point + //! @param datum Point which is to be evaluated + //! @param covariate (Optional) covariate vector associated to datum + //! @return The evaluation of the lpdf + double conditional_pred_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = + Eigen::RowVectorXd(0)) const override { + return get_marg_lpdf(posterior_hypers, datum, covariate); + } + + //! Evaluates the log-prior predictive distr. of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf + virtual Eigen::VectorXd prior_pred_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, + 0)) const override; + + //! Evaluates the log-prior predictive distr. of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf + virtual Eigen::VectorXd conditional_pred_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, + 0)) const override; + + //! Generates new state values from the centering posterior distribution + //! @param update_params Save posterior hypers after the computation? + void sample_full_cond(bool update_params = true) override { + if (this->card == 0) { + // No posterior update possible + static_cast(this)->sample_prior(); + } else { + Hyperparams params = + update_params + ? static_cast(this)->compute_posterior_hypers() + : posterior_hypers; + state = static_cast(this)->draw(params); + } + } + + //! Saves posterior hyperparameters to the corresponding class member + void save_posterior_hypers() { + posterior_hypers = + static_cast(this)->compute_posterior_hypers(); + } + + //! Returns whether the hierarchy represents a conjugate model or not + bool is_conjugate() const override { return true; } + + protected: + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @param covariate Covariate vector associated to datum + //! @return The evaluation of the lpdf + virtual double marg_lpdf(const Hyperparams ¶ms, + const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const { + if (!this->is_dependent()) { + throw std::runtime_error( + "Cannot call this function from a non-dependent hierarchy"); + } else { + throw std::runtime_error("Not implemented"); + } + } + + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf + virtual double marg_lpdf(const Hyperparams ¶ms, + const Eigen::RowVectorXd &datum) const { + if (this->is_dependent()) { + throw std::runtime_error( + "Cannot call this function from a dependent hierarchy"); + } else { + throw std::runtime_error("Not implemented"); + } + } +}; + +template +double ConjugateHierarchy::get_marg_lpdf( + const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const { + if (this->is_dependent()) { + return marg_lpdf(params, datum, covariate); + } else { + return marg_lpdf(params, datum); + } +} + +template +Eigen::VectorXd +ConjugateHierarchy::prior_pred_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { + Eigen::VectorXd lpdf(data.rows()); + if (covariates.cols() == 0) { + // Pass null value as covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->prior_pred_lpdf( + data.row(i), Eigen::RowVectorXd(0)); + } + } else if (covariates.rows() == 1) { + // Use unique covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->prior_pred_lpdf( + data.row(i), covariates.row(0)); + } + } else { + // Use different covariates + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->prior_pred_lpdf( + data.row(i), covariates.row(i)); + } + } + return lpdf; +} + +template +Eigen::VectorXd ConjugateHierarchy:: + conditional_pred_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { + Eigen::VectorXd lpdf(data.rows()); + if (covariates.cols() == 0) { + // Pass null value as covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->conditional_pred_lpdf( + data.row(i), Eigen::RowVectorXd(0)); + } + } else if (covariates.rows() == 1) { + // Use unique covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->conditional_pred_lpdf( + data.row(i), covariates.row(0)); + } + } else { + // Use different covariates + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->conditional_pred_lpdf( + data.row(i), covariates.row(i)); + } + } + return lpdf; +} + +#endif diff --git a/src/hierarchies/.old/lin_reg_uni_hierarchy.cc b/src/hierarchies/.old/lin_reg_uni_hierarchy.cc new file mode 100644 index 000000000..d48f6ada8 --- /dev/null +++ b/src/hierarchies/.old/lin_reg_uni_hierarchy.cc @@ -0,0 +1,166 @@ +#include "lin_reg_uni_hierarchy.h" + +#include +#include +#include + +#include "src/utils/eigen_utils.h" +#include "src/utils/proto_utils.h" +#include "src/utils/rng.h" + +double LinRegUniHierarchy::like_lpdf( + const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const { + return stan::math::normal_lpdf( + datum(0), state.regression_coeffs.dot(covariate), sqrt(state.var)); +} + +double LinRegUniHierarchy::marg_lpdf( + const LinRegUni::Hyperparams ¶ms, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const { + double sig_n = sqrt( + (1 + (covariate * params.var_scaling_inv * covariate.transpose())(0)) * + params.scale / params.shape); + return stan::math::student_t_lpdf(datum(0), 2 * params.shape, + covariate.dot(params.mean), sig_n); +} + +void LinRegUniHierarchy::initialize_state() { + state.regression_coeffs = hypers->mean; + state.var = hypers->scale / (hypers->shape + 1); +} + +void LinRegUniHierarchy::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); + dim = hypers->mean.size(); + hypers->var_scaling = + bayesmix::to_eigen(prior->fixed_values().var_scaling()); + hypers->var_scaling_inv = stan::math::inverse_spd(hypers->var_scaling); + hypers->shape = prior->fixed_values().shape(); + hypers->scale = prior->fixed_values().scale(); + // Check validity + if (dim != hypers->var_scaling.rows()) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + bayesmix::check_spd(hypers->var_scaling); + if (hypers->shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers->scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } + + else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void LinRegUniHierarchy::update_hypers( + const std::vector &states) { + auto &rng = bayesmix::Rng::Instance().get(); + if (prior->has_fixed_values()) { + return; + } + + else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +LinRegUni::State LinRegUniHierarchy::draw( + const LinRegUni::Hyperparams ¶ms) { + auto &rng = bayesmix::Rng::Instance().get(); + LinRegUni::State out; + out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); + out.regression_coeffs = stan::math::multi_normal_prec_rng( + params.mean, params.var_scaling / out.var, rng); + return out; +} + +void LinRegUniHierarchy::update_summary_statistics( + const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate, + bool add) { + if (add) { + data_sum_squares += datum(0) * datum(0); + covar_sum_squares += covariate.transpose() * covariate; + mixed_prod += datum(0) * covariate.transpose(); + } else { + data_sum_squares -= datum(0) * datum(0); + covar_sum_squares -= covariate.transpose() * covariate; + mixed_prod -= datum(0) * covariate.transpose(); + } +} + +void LinRegUniHierarchy::clear_summary_statistics() { + mixed_prod = Eigen::VectorXd::Zero(dim); + data_sum_squares = 0.0; + covar_sum_squares = Eigen::MatrixXd::Zero(dim, dim); +} + +LinRegUni::Hyperparams LinRegUniHierarchy::compute_posterior_hypers() const { + if (card == 0) { // no update possible + return *hypers; + } + // Compute posterior hyperparameters + LinRegUni::Hyperparams post_params; + post_params.var_scaling = covar_sum_squares + hypers->var_scaling; + auto llt = post_params.var_scaling.llt(); + post_params.var_scaling_inv = llt.solve(Eigen::MatrixXd::Identity(dim, dim)); + post_params.mean = + llt.solve(mixed_prod + hypers->var_scaling * hypers->mean); + post_params.shape = hypers->shape + 0.5 * card; + post_params.scale = + hypers->scale + + 0.5 * (data_sum_squares + + hypers->mean.transpose() * hypers->var_scaling * hypers->mean - + post_params.mean.transpose() * post_params.var_scaling * + post_params.mean); + return post_params; +} + +void LinRegUniHierarchy::set_state_from_proto( + const google::protobuf::Message &state_) { + auto &statecast = downcast_state(state_); + state.regression_coeffs = + bayesmix::to_eigen(statecast.lin_reg_uni_ls_state().regression_coeffs()); + state.var = statecast.lin_reg_uni_ls_state().var(); + set_card(statecast.cardinality()); +} + +std::shared_ptr +LinRegUniHierarchy::get_state_proto() const { + bayesmix::LinRegUniLSState state_; + bayesmix::to_proto(state.regression_coeffs, + state_.mutable_regression_coeffs()); + state_.set_var(state.var); + + auto out = std::make_shared(); + out->mutable_lin_reg_uni_ls_state()->CopyFrom(state_); + return out; +} + +void LinRegUniHierarchy::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).lin_reg_uni_state(); + hypers->mean = bayesmix::to_eigen(hyperscast.mean()); + hypers->var_scaling = bayesmix::to_eigen(hyperscast.var_scaling()); + hypers->scale = hyperscast.scale(); + hypers->shape = hyperscast.shape(); +} + +std::shared_ptr +LinRegUniHierarchy::get_hypers_proto() const { + bayesmix::MultiNormalIGDistribution hypers_; + bayesmix::to_proto(hypers->mean, hypers_.mutable_mean()); + bayesmix::to_proto(hypers->var_scaling, hypers_.mutable_var_scaling()); + hypers_.set_shape(hypers->shape); + hypers_.set_scale(hypers->scale); + + auto out = std::make_shared(); + out->mutable_lin_reg_uni_state()->CopyFrom(hypers_); + return out; +} diff --git a/src/hierarchies/.old/lin_reg_uni_hierarchy.h b/src/hierarchies/.old/lin_reg_uni_hierarchy.h new file mode 100644 index 000000000..7789a6065 --- /dev/null +++ b/src/hierarchies/.old/lin_reg_uni_hierarchy.h @@ -0,0 +1,144 @@ +#ifndef BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "conjugate_hierarchy.h" +#include "hierarchy_id.pb.h" +#include "hierarchy_prior.pb.h" + +//! Linear regression hierarchy for univariate data. + +//! This class implements a dependent hierarchy which represents the classical +//! univariate Bayesian linear regression model, i.e.: +//! y_i | \beta, x_i, \sigma^2 \sim N(\beta^T x_i, sigma^2) +//! \beta | \sigma^2 \sim N(\mu, sigma^2 Lambda^{-1}) +//! \sigma^2 \sim InvGamma(a, b) +//! +//! The state consists of the `regression_coeffs` \beta, and the `var` sigma^2. +//! Lambda is called the variance-scaling factor. For more information, please +//! refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and +//! `ConjugateHierarchy`. + +namespace LinRegUni { +//! Custom container for State values +struct State { + Eigen::VectorXd regression_coeffs; + double var; +}; + +//! Custom container for Hyperparameters values +struct Hyperparams { + Eigen::VectorXd mean; + Eigen::MatrixXd var_scaling; + Eigen::MatrixXd var_scaling_inv; + double shape; + double scale; +}; +} // namespace LinRegUni + +class LinRegUniHierarchy + : public ConjugateHierarchy { + public: + LinRegUniHierarchy() = default; + ~LinRegUniHierarchy() = default; + + //! Updates hyperparameter values given a vector of cluster states + void update_hypers(const std::vector + &states) override; + + //! Updates state values using the given (prior or posterior) hyperparameters + LinRegUni::State draw(const LinRegUni::Hyperparams ¶ms); + + //! Updates cluster statistics when a datum is added or removed from it + //! @param datum Data point which is being added or removed + //! @param covariate Covariate vector associated to datum + //! @param add Whether the datum is being added or removed + void update_summary_statistics(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, + bool add) override; + + //! Resets summary statistics for this cluster + void clear_summary_statistics() override; + + //! Returns the Protobuf ID associated to this class + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::LinRegUni; + } + + //! Read and set state values from a given Protobuf message + void set_state_from_proto(const google::protobuf::Message &state_) override; + + //! Read and set hyperparameter values from a given Protobuf message + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + //! Writes current state to a Protobuf message and return a shared_ptr + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::ClusterState message by adding the appropriate type + std::shared_ptr get_state_proto() + const override; + + //! Writes current value of hyperparameters to a Protobuf message and + //! return a shared_ptr. + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::HierarchyHypers message by adding the appropriate type + std::shared_ptr get_hypers_proto() + const override; + + //! Returns the dimension of the coefficients vector + unsigned int get_dim() const { return dim; } + + //! Computes and return posterior hypers given data currently in this cluster + LinRegUni::Hyperparams compute_posterior_hypers() const; + + //! Returns whether the hierarchy models multivariate data or not + bool is_multivariate() const override { return false; } + + //! Returns whether the hierarchy depends on covariate values or not + bool is_dependent() const override { return true; } + + protected: + //! Evaluates the log-likelihood of data in a single point + //! @param datum Point which is to be evaluated + //! @param covariate Covariate vector associated to datum + //! @return The evaluation of the lpdf + double like_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const override; + + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @param covariate Covariate vector associated to datum + //! @return The evaluation of the lpdf + double marg_lpdf(const LinRegUni::Hyperparams ¶ms, + const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const override; + + //! Initializes state parameters to appropriate values + void initialize_state() override; + + //! Initializes hierarchy hyperparameters to appropriate values + void initialize_hypers() override; + + //! Dimension of the coefficients vector + unsigned int dim; + + //! Represents pieces of y^t y + double data_sum_squares; + + //! Represents pieces of X^T X + Eigen::MatrixXd covar_sum_squares; + + //! Represents pieces of X^t y + Eigen::VectorXd mixed_prod; +}; + +#endif // BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ diff --git a/src/hierarchies/.old/nnig_hierarchy.cc b/src/hierarchies/.old/nnig_hierarchy.cc new file mode 100644 index 000000000..c2b055178 --- /dev/null +++ b/src/hierarchies/.old/nnig_hierarchy.cc @@ -0,0 +1,267 @@ +#include "nnig_hierarchy.h" + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "hierarchy_prior.pb.h" +#include "ls_state.pb.h" +#include "src/utils/rng.h" + +double NNIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { + return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); +} + +double NNIGHierarchy::marg_lpdf(const NNIG::Hyperparams ¶ms, + const Eigen::RowVectorXd &datum) const { + double sig_n = sqrt(params.scale * (params.var_scaling + 1) / + (params.shape * params.var_scaling)); + return stan::math::student_t_lpdf(datum(0), 2 * params.shape, params.mean, + sig_n); +} + +void NNIGHierarchy::initialize_state() { + state.mean = hypers->mean; + state.var = hypers->scale / (hypers->shape + 1); +} + +void NNIGHierarchy::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->mean = prior->fixed_values().mean(); + hypers->var_scaling = prior->fixed_values().var_scaling(); + hypers->shape = prior->fixed_values().shape(); + hypers->scale = prior->fixed_values().scale(); + // Check validity + if (hypers->var_scaling <= 0) { + throw std::invalid_argument("Variance-scaling parameter must be > 0"); + } + if (hypers->shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers->scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } + + else if (prior->has_normal_mean_prior()) { + // Set initial values + hypers->mean = prior->normal_mean_prior().mean_prior().mean(); + hypers->var_scaling = prior->normal_mean_prior().var_scaling(); + hypers->shape = prior->normal_mean_prior().shape(); + hypers->scale = prior->normal_mean_prior().scale(); + // Check validity + if (hypers->var_scaling <= 0) { + throw std::invalid_argument("Variance-scaling parameter must be > 0"); + } + if (hypers->shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers->scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } + + else if (prior->has_ngg_prior()) { + // Get hyperparameters: + // for mu0 + double mu00 = prior->ngg_prior().mean_prior().mean(); + double sigma00 = prior->ngg_prior().mean_prior().var(); + // for lambda0 + double alpha00 = prior->ngg_prior().var_scaling_prior().shape(); + double beta00 = prior->ngg_prior().var_scaling_prior().rate(); + // for beta0 + double a00 = prior->ngg_prior().scale_prior().shape(); + double b00 = prior->ngg_prior().scale_prior().rate(); + // for alpha0 + double alpha0 = prior->ngg_prior().shape(); + // Check validity + if (sigma00 <= 0) { + throw std::invalid_argument("Variance parameter must be > 0"); + } + if (alpha00 <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (beta00 <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + if (a00 <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (b00 <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + if (alpha0 <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + // Set initial values + hypers->mean = mu00; + hypers->var_scaling = alpha00 / beta00; + hypers->shape = alpha0; + hypers->scale = a00 / b00; + } + + else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void NNIGHierarchy::update_hypers( + const std::vector &states) { + auto &rng = bayesmix::Rng::Instance().get(); + + if (prior->has_fixed_values()) { + return; + } + + else if (prior->has_normal_mean_prior()) { + // Get hyperparameters + double mu00 = prior->normal_mean_prior().mean_prior().mean(); + double sig200 = prior->normal_mean_prior().mean_prior().var(); + double lambda0 = prior->normal_mean_prior().var_scaling(); + // Compute posterior hyperparameters + double prec = 0.0; + double num = 0.0; + for (auto &st : states) { + double mean = st.uni_ls_state().mean(); + double var = st.uni_ls_state().var(); + prec += 1 / var; + num += mean / var; + } + prec = 1 / sig200 + lambda0 * prec; + num = mu00 / sig200 + lambda0 * num; + double mu_n = num / prec; + double sig2_n = 1 / prec; + // Update hyperparameters with posterior random sampling + hypers->mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); + } + + else if (prior->has_ngg_prior()) { + // Get hyperparameters: + // for mu0 + double mu00 = prior->ngg_prior().mean_prior().mean(); + double sig200 = prior->ngg_prior().mean_prior().var(); + // for lambda0 + double alpha00 = prior->ngg_prior().var_scaling_prior().shape(); + double beta00 = prior->ngg_prior().var_scaling_prior().rate(); + // for tau0 + double a00 = prior->ngg_prior().scale_prior().shape(); + double b00 = prior->ngg_prior().scale_prior().rate(); + // Compute posterior hyperparameters + double b_n = 0.0; + double num = 0.0; + double beta_n = 0.0; + for (auto &st : states) { + double mean = st.uni_ls_state().mean(); + double var = st.uni_ls_state().var(); + b_n += 1 / var; + num += mean / var; + beta_n += (hypers->mean - mean) * (hypers->mean - mean) / var; + } + double var = hypers->var_scaling * b_n + 1 / sig200; + b_n += b00; + num = hypers->var_scaling * num + mu00 / sig200; + beta_n = beta00 + 0.5 * beta_n; + double sig_n = 1 / var; + double mu_n = num / var; + double alpha_n = alpha00 + 0.5 * states.size(); + double a_n = a00 + states.size() * hypers->shape; + // Update hyperparameters with posterior random Gibbs sampling + hypers->mean = stan::math::normal_rng(mu_n, sig_n, rng); + hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); + hypers->scale = stan::math::gamma_rng(a_n, b_n, rng); + } + + else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +NNIG::State NNIGHierarchy::draw(const NNIG::Hyperparams ¶ms) { + auto &rng = bayesmix::Rng::Instance().get(); + NNIG::State out; + out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); + out.mean = stan::math::normal_rng(params.mean, + sqrt(state.var / params.var_scaling), rng); + return out; +} + +void NNIGHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, + bool add) { + if (add) { + data_sum += datum(0); + data_sum_squares += datum(0) * datum(0); + } else { + data_sum -= datum(0); + data_sum_squares -= datum(0) * datum(0); + } +} + +void NNIGHierarchy::clear_summary_statistics() { + data_sum = 0; + data_sum_squares = 0; +} + +NNIG::Hyperparams NNIGHierarchy::compute_posterior_hypers() const { + // Initialize relevant variables + if (card == 0) { // no update possible + return *hypers; + } + // Compute posterior hyperparameters + NNIG::Hyperparams post_params; + double y_bar = data_sum / (1.0 * card); // sample mean + double ss = data_sum_squares - card * y_bar * y_bar; + post_params.mean = (hypers->var_scaling * hypers->mean + data_sum) / + (hypers->var_scaling + card); + post_params.var_scaling = hypers->var_scaling + card; + post_params.shape = hypers->shape + 0.5 * card; + post_params.scale = hypers->scale + 0.5 * ss + + 0.5 * hypers->var_scaling * card * + (y_bar - hypers->mean) * (y_bar - hypers->mean) / + (card + hypers->var_scaling); + return post_params; +} + +void NNIGHierarchy::set_state_from_proto( + const google::protobuf::Message &state_) { + auto &statecast = downcast_state(state_); + state.mean = statecast.uni_ls_state().mean(); + state.var = statecast.uni_ls_state().var(); + set_card(statecast.cardinality()); +} + +std::shared_ptr +NNIGHierarchy::get_state_proto() const { + bayesmix::UniLSState state_; + state_.set_mean(state.mean); + state_.set_var(state.var); + + auto out = std::make_shared(); + out->mutable_uni_ls_state()->CopyFrom(state_); + return out; +} + +void NNIGHierarchy::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).nnig_state(); + hypers->mean = hyperscast.mean(); + hypers->var_scaling = hyperscast.var_scaling(); + hypers->scale = hyperscast.scale(); + hypers->shape = hyperscast.shape(); +} + +std::shared_ptr +NNIGHierarchy::get_hypers_proto() const { + bayesmix::NIGDistribution hypers_; + hypers_.set_mean(hypers->mean); + hypers_.set_var_scaling(hypers->var_scaling); + hypers_.set_shape(hypers->shape); + hypers_.set_scale(hypers->scale); + + auto out = std::make_shared(); + out->mutable_nnig_state()->CopyFrom(hypers_); + return out; +} diff --git a/src/hierarchies/.old/nnig_hierarchy.h b/src/hierarchies/.old/nnig_hierarchy.h new file mode 100644 index 000000000..7911691e9 --- /dev/null +++ b/src/hierarchies/.old/nnig_hierarchy.h @@ -0,0 +1,122 @@ +#ifndef BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "conjugate_hierarchy.h" +#include "hierarchy_id.pb.h" +#include "hierarchy_prior.pb.h" + +//! Conjugate Normal Normal-InverseGamma hierarchy for univariate data. + +//! This class represents a hierarchical model where data are distributed +//! according to a normal likelihood, the parameters of which have a +//! Normal-InverseGamma centering distribution. That is: +//! f(x_i|mu,sig) = N(mu,sig^2) +//! (mu,sig^2) ~ N-IG(mu0, lambda0, alpha0, beta0) +//! The state is composed of mean and variance. The state hyperparameters, +//! contained in the Hypers object, are (mu_0, lambda0, alpha0, beta0), all +//! scalar values. Note that this hierarchy is conjugate, thus the marginal +//! distribution is available in closed form. For more information, please +//! refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and +//! `ConjugateHierarchy`. + +namespace NNIG { +//! Custom container for State values +struct State { + double mean, var; +}; + +//! Custom container for Hyperparameters values +struct Hyperparams { + double mean, var_scaling, shape, scale; +}; + +}; // namespace NNIG + +class NNIGHierarchy + : public ConjugateHierarchy { + public: + NNIGHierarchy() = default; + ~NNIGHierarchy() = default; + + //! Updates hyperparameter values given a vector of cluster states + void update_hypers(const std::vector + &states) override; + + //! Updates state values using the given (prior or posterior) hyperparameters + NNIG::State draw(const NNIG::Hyperparams ¶ms); + + //! Resets summary statistics for this cluster + void clear_summary_statistics() override; + + //! Returns the Protobuf ID associated to this class + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::NNIG; + } + + //! Read and set state values from a given Protobuf message + void set_state_from_proto(const google::protobuf::Message &state_) override; + + //! Read and set hyperparameter values from a given Protobuf message + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + //! Writes current state to a Protobuf message and return a shared_ptr + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::ClusterState message by adding the appropriate type + std::shared_ptr get_state_proto() + const override; + + //! Writes current value of hyperparameters to a Protobuf message and + //! return a shared_ptr. + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::HierarchyHypers message by adding the appropriate type + std::shared_ptr get_hypers_proto() + const override; + + //! Computes and return posterior hypers given data currently in this cluster + NNIG::Hyperparams compute_posterior_hypers() const; + + //! Returns whether the hierarchy models multivariate data or not + bool is_multivariate() const override { return false; } + + protected: + //! Evaluates the log-likelihood of data in a single point + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf + double like_lpdf(const Eigen::RowVectorXd &datum) const override; + + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf + double marg_lpdf(const NNIG::Hyperparams ¶ms, + const Eigen::RowVectorXd &datum) const override; + + //! Updates cluster statistics when a datum is added or removed from it + //! @param datum Data point which is being added or removed + //! @param add Whether the datum is being added or removed + void update_summary_statistics(const Eigen::RowVectorXd &datum, + bool add) override; + + //! Initializes state parameters to appropriate values + void initialize_state() override; + + //! Initializes hierarchy hyperparameters to appropriate values + void initialize_hypers() override; + + //! Sum of data points currently belonging to the cluster + double data_sum = 0; + + //! Sum of squared data points currently belonging to the cluster + double data_sum_squares = 0; +}; + +#endif // BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ diff --git a/src/hierarchies/.old/nnw_hierarchy.cc b/src/hierarchies/.old/nnw_hierarchy.cc new file mode 100644 index 000000000..e1e58275f --- /dev/null +++ b/src/hierarchies/.old/nnw_hierarchy.cc @@ -0,0 +1,373 @@ +#include "nnw_hierarchy.h" + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "hierarchy_prior.pb.h" +#include "ls_state.pb.h" +#include "matrix.pb.h" +#include "src/utils/distributions.h" +#include "src/utils/eigen_utils.h" +#include "src/utils/proto_utils.h" +#include "src/utils/rng.h" + +double NNWHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { + return bayesmix::multi_normal_prec_lpdf(datum, state.mean, state.prec_chol, + state.prec_logdet); +} + +double NNWHierarchy::marg_lpdf(const NNW::Hyperparams ¶ms, + const Eigen::RowVectorXd &datum) const { + NNW::Hyperparams pred_params = get_predictive_t_parameters(params); + Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); + double logdet = 2 * log(diag.array()).sum(); + + return bayesmix::multi_student_t_invscale_lpdf( + datum, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, + logdet); +} + +Eigen::VectorXd NNWHierarchy::like_lpdf_grid( + const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { + // Custom, optimized grid method + return bayesmix::multi_normal_prec_lpdf_grid( + data, state.mean, state.prec_chol, state.prec_logdet); +} + +Eigen::VectorXd NNWHierarchy::prior_pred_lpdf_grid( + const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { + // Custom, optimized grid method + NNW::Hyperparams pred_params = get_predictive_t_parameters(*hypers); + Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); + double logdet = 2 * log(diag.array()).sum(); + + return bayesmix::multi_student_t_invscale_lpdf_grid( + data, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, + logdet); +} + +Eigen::VectorXd NNWHierarchy::conditional_pred_lpdf_grid( + const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { + // Custom, optimized grid method + NNW::Hyperparams pred_params = + get_predictive_t_parameters(compute_posterior_hypers()); + Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); + double logdet = 2 * log(diag.array()).sum(); + + return bayesmix::multi_student_t_invscale_lpdf_grid( + data, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, + logdet); +} + +void NNWHierarchy::initialize_state() { + state.mean = hypers->mean; + write_prec_to_state( + hypers->var_scaling * Eigen::MatrixXd::Identity(dim, dim), &state); +} + +void NNWHierarchy::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); + dim = hypers->mean.size(); + hypers->var_scaling = prior->fixed_values().var_scaling(); + hypers->scale = bayesmix::to_eigen(prior->fixed_values().scale()); + hypers->deg_free = prior->fixed_values().deg_free(); + // Check validity + if (hypers->var_scaling <= 0) { + throw std::invalid_argument("Variance-scaling parameter must be > 0"); + } + if (dim != hypers->scale.rows()) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + if (hypers->deg_free <= dim - 1) { + throw std::invalid_argument("Degrees of freedom parameter is not valid"); + } + } + + else if (prior->has_normal_mean_prior()) { + // Get hyperparameters + Eigen::VectorXd mu00 = + bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().mean()); + dim = mu00.size(); + Eigen::MatrixXd sigma00 = + bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().var()); + double lambda0 = prior->normal_mean_prior().var_scaling(); + Eigen::MatrixXd tau0 = + bayesmix::to_eigen(prior->normal_mean_prior().scale()); + double nu0 = prior->normal_mean_prior().deg_free(); + // Check validity + unsigned int dim = mu00.size(); + if (sigma00.rows() != dim or tau0.rows() != dim) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + bayesmix::check_spd(sigma00); + if (lambda0 <= 0) { + throw std::invalid_argument("Variance-scaling parameter must be > 0"); + } + bayesmix::check_spd(tau0); + if (nu0 <= dim - 1) { + throw std::invalid_argument("Degrees of freedom parameter is not valid"); + } + // Set initial values + hypers->mean = mu00; + hypers->var_scaling = lambda0; + hypers->scale = tau0; + hypers->deg_free = nu0; + } + + else if (prior->has_ngiw_prior()) { + // Get hyperparameters: + // for mu0 + Eigen::VectorXd mu00 = + bayesmix::to_eigen(prior->ngiw_prior().mean_prior().mean()); + dim = mu00.size(); + Eigen::MatrixXd sigma00 = + bayesmix::to_eigen(prior->ngiw_prior().mean_prior().var()); + // for lambda0 + double alpha00 = prior->ngiw_prior().var_scaling_prior().shape(); + double beta00 = prior->ngiw_prior().var_scaling_prior().rate(); + // for tau0 + double nu00 = prior->ngiw_prior().scale_prior().deg_free(); + Eigen::MatrixXd tau00 = + bayesmix::to_eigen(prior->ngiw_prior().scale_prior().scale()); + // for nu0 + double nu0 = prior->ngiw_prior().deg_free(); + // Check validity: + // dimensionality + if (sigma00.rows() != dim or tau00.rows() != dim) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + // for mu0 + bayesmix::check_spd(sigma00); + // for lambda0 + if (alpha00 <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (beta00 <= 0) { + throw std::invalid_argument("Rate parameter must be > 0"); + } + // for tau0 + if (nu00 <= 0) { + throw std::invalid_argument("Degrees of freedom parameter must be > 0"); + } + bayesmix::check_spd(tau00); + // check nu0 + if (nu0 <= dim - 1) { + throw std::invalid_argument("Degrees of freedom parameter is not valid"); + } + // Set initial values + hypers->mean = mu00; + hypers->var_scaling = alpha00 / beta00; + hypers->scale = tau00 / (nu00 + dim + 1); + hypers->deg_free = nu0; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } + hypers->scale_inv = stan::math::inverse_spd(hypers->scale); + hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); +} + +void NNWHierarchy::update_hypers( + const std::vector &states) { + auto &rng = bayesmix::Rng::Instance().get(); + if (prior->has_fixed_values()) { + return; + } + + else if (prior->has_normal_mean_prior()) { + // Get hyperparameters + Eigen::VectorXd mu00 = + bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().mean()); + Eigen::MatrixXd sigma00 = + bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().var()); + double lambda0 = prior->normal_mean_prior().var_scaling(); + // Compute posterior hyperparameters + Eigen::MatrixXd sigma00inv = stan::math::inverse_spd(sigma00); + Eigen::MatrixXd prec = Eigen::MatrixXd::Zero(dim, dim); + Eigen::VectorXd num = Eigen::MatrixXd::Zero(dim, 1); + for (auto &st : states) { + Eigen::MatrixXd prec_i = bayesmix::to_eigen(st.multi_ls_state().prec()); + prec += prec_i; + num += prec_i * bayesmix::to_eigen(st.multi_ls_state().mean()); + } + prec = hypers->var_scaling * prec + sigma00inv; + num = hypers->var_scaling * num + sigma00inv * mu00; + Eigen::VectorXd mu_n = prec.llt().solve(num); + // Update hyperparameters with posterior sampling + hypers->mean = stan::math::multi_normal_prec_rng(mu_n, prec, rng); + } + + else if (prior->has_ngiw_prior()) { + // Get hyperparameters: + // for mu0 + Eigen::VectorXd mu00 = + bayesmix::to_eigen(prior->ngiw_prior().mean_prior().mean()); + Eigen::MatrixXd sigma00 = + bayesmix::to_eigen(prior->ngiw_prior().mean_prior().var()); + // for lambda0 + double alpha00 = prior->ngiw_prior().var_scaling_prior().shape(); + double beta00 = prior->ngiw_prior().var_scaling_prior().rate(); + // for tau0 + double nu00 = prior->ngiw_prior().scale_prior().deg_free(); + Eigen::MatrixXd tau00 = + bayesmix::to_eigen(prior->ngiw_prior().scale_prior().scale()); + // Compute posterior hyperparameters + Eigen::MatrixXd sigma00inv = stan::math::inverse_spd(sigma00); + Eigen::MatrixXd tau_n = Eigen::MatrixXd::Zero(dim, dim); + Eigen::VectorXd num = Eigen::MatrixXd::Zero(dim, 1); + double beta_n = 0.0; + for (auto &st : states) { + Eigen::VectorXd mean = bayesmix::to_eigen(st.multi_ls_state().mean()); + Eigen::MatrixXd prec = bayesmix::to_eigen(st.multi_ls_state().prec()); + tau_n += prec; + num += prec * mean; + beta_n += + (hypers->mean - mean).transpose() * prec * (hypers->mean - mean); + } + Eigen::MatrixXd prec_n = hypers->var_scaling * tau_n + sigma00inv; + tau_n += tau00; + num = hypers->var_scaling * num + sigma00inv * mu00; + beta_n = beta00 + 0.5 * beta_n; + Eigen::MatrixXd sig_n = stan::math::inverse_spd(prec_n); + Eigen::VectorXd mu_n = sig_n * num; + double alpha_n = alpha00 + 0.5 * states.size(); + double nu_n = nu00 + states.size() * hypers->deg_free; + // Update hyperparameters with posterior random Gibbs sampling + hypers->mean = stan::math::multi_normal_rng(mu_n, sig_n, rng); + hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); + hypers->scale = stan::math::inv_wishart_rng(nu_n, tau_n, rng); + hypers->scale_inv = stan::math::inverse_spd(hypers->scale); + hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); + } + + else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +NNW::State NNWHierarchy::draw(const NNW::Hyperparams ¶ms) { + auto &rng = bayesmix::Rng::Instance().get(); + Eigen::MatrixXd tau_new = + stan::math::wishart_rng(params.deg_free, params.scale, rng); + // Update state + NNW::State out; + out.mean = stan::math::multi_normal_prec_rng( + params.mean, tau_new * params.var_scaling, rng); + write_prec_to_state(tau_new, &out); + return out; +} + +void NNWHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, + bool add) { + if (add) { + data_sum += datum.transpose(); + data_sum_squares += datum.transpose() * datum; + } else { + data_sum -= datum.transpose(); + data_sum_squares -= datum.transpose() * datum; + } +} + +void NNWHierarchy::clear_summary_statistics() { + data_sum = Eigen::VectorXd::Zero(dim); + data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); +} + +NNW::Hyperparams NNWHierarchy::compute_posterior_hypers() const { + if (card == 0) { // no update possible + return *hypers; + } + // Compute posterior hyperparameters + NNW::Hyperparams post_params; + post_params.var_scaling = hypers->var_scaling + card; + post_params.deg_free = hypers->deg_free + card; + Eigen::VectorXd mubar = data_sum.array() / card; // sample mean + post_params.mean = (hypers->var_scaling * hypers->mean + card * mubar) / + (hypers->var_scaling + card); + // Compute tau_n + Eigen::MatrixXd tau_temp = + data_sum_squares - card * mubar * mubar.transpose(); + tau_temp += (card * hypers->var_scaling / (card + hypers->var_scaling)) * + (mubar - hypers->mean) * (mubar - hypers->mean).transpose(); + post_params.scale_inv = tau_temp + hypers->scale_inv; + post_params.scale = stan::math::inverse_spd(post_params.scale_inv); + post_params.scale_chol = + Eigen::LLT(post_params.scale).matrixU(); + return post_params; +} + +void NNWHierarchy::set_state_from_proto( + const google::protobuf::Message &state_) { + auto &statecast = downcast_state(state_); + state.mean = to_eigen(statecast.multi_ls_state().mean()); + state.prec = to_eigen(statecast.multi_ls_state().prec()); + state.prec_chol = to_eigen(statecast.multi_ls_state().prec_chol()); + Eigen::VectorXd diag = state.prec_chol.diagonal(); + state.prec_logdet = 2 * log(diag.array()).sum(); + set_card(statecast.cardinality()); +} + +std::shared_ptr +NNWHierarchy::get_state_proto() const { + bayesmix::MultiLSState state_; + bayesmix::to_proto(state.mean, state_.mutable_mean()); + bayesmix::to_proto(state.prec, state_.mutable_prec()); + bayesmix::to_proto(state.prec_chol, state_.mutable_prec_chol()); + + auto out = std::make_shared(); + out->mutable_multi_ls_state()->CopyFrom(state_); + return out; +} + +void NNWHierarchy::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).nnw_state(); + hypers->mean = to_eigen(hyperscast.mean()); + hypers->var_scaling = hyperscast.var_scaling(); + hypers->deg_free = hyperscast.deg_free(); + hypers->scale = to_eigen(hyperscast.scale()); +} + +std::shared_ptr +NNWHierarchy::get_hypers_proto() const { + bayesmix::NWDistribution hypers_; + bayesmix::to_proto(hypers->mean, hypers_.mutable_mean()); + hypers_.set_var_scaling(hypers->var_scaling); + hypers_.set_deg_free(hypers->deg_free); + bayesmix::to_proto(hypers->scale, hypers_.mutable_scale()); + + auto out = std::make_shared(); + out->mutable_nnw_state()->CopyFrom(hypers_); + return out; +} + +void NNWHierarchy::write_prec_to_state(const Eigen::MatrixXd &prec_, + NNW::State *out) { + out->prec = prec_; + // Update prec utilities + out->prec_chol = Eigen::LLT(prec_).matrixU(); + Eigen::VectorXd diag = out->prec_chol.diagonal(); + out->prec_logdet = 2 * log(diag.array()).sum(); +} + +NNW::Hyperparams NNWHierarchy::get_predictive_t_parameters( + const NNW::Hyperparams ¶ms) const { + // Compute dof and scale of marginal distribution + double nu_n = params.deg_free - dim + 1; + double coeff = (params.var_scaling + 1) / (params.var_scaling * nu_n); + Eigen::MatrixXd scale_chol_n = params.scale_chol / std::sqrt(coeff); + + NNW::Hyperparams out; + out.mean = params.mean; + out.deg_free = nu_n; + out.scale_chol = scale_chol_n; + return out; +} diff --git a/src/hierarchies/.old/nnw_hierarchy.h b/src/hierarchies/.old/nnw_hierarchy.h new file mode 100644 index 000000000..1b149d422 --- /dev/null +++ b/src/hierarchies/.old/nnw_hierarchy.h @@ -0,0 +1,168 @@ +#ifndef BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "conjugate_hierarchy.h" +#include "hierarchy_id.pb.h" +#include "hierarchy_prior.pb.h" + +//! Normal Normal-Wishart hierarchy for multivariate data. + +//! This class represents a hierarchy, i.e. a cluster, whose multivariate data +//! are distributed according to a multinomial normal likelihood, the +//! parameters of which have a Normal-Wishart centering distribution. That is: +//! f(x_i|mu,tau) = N(mu,tau^{-1}) +//! (mu,tau) ~ NW(mu0, lambda0, tau0, nu0) +//! The state is composed of mean and precision matrix. The Cholesky factor and +//! log-determinant of the latter are also included in the container for +//! efficiency reasons. The state's hyperparameters, contained in the Hypers +//! object, are (mu0, lambda0, tau0, nu0), which are respectively vector, +//! scalar, matrix, and scalar. Note that this hierarchy is conjugate, thus the +//! marginal distribution is available in closed form. For more information, +//! please refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and +//! `ConjugateHierarchy`. + +namespace NNW { +//! Custom container for State values +struct State { + Eigen::VectorXd mean; + Eigen::MatrixXd prec; + Eigen::MatrixXd prec_chol; + double prec_logdet; +}; + +//! Custom container for Hyperparameters values +struct Hyperparams { + Eigen::VectorXd mean; + double var_scaling; + double deg_free; + Eigen::MatrixXd scale; + Eigen::MatrixXd scale_inv; + Eigen::MatrixXd scale_chol; +}; +} // namespace NNW + +class NNWHierarchy + : public ConjugateHierarchy { + public: + NNWHierarchy() = default; + ~NNWHierarchy() = default; + + // EVALUATION FUNCTIONS FOR GRIDS OF POINTS + //! Evaluates the log-likelihood of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf + Eigen::VectorXd like_lpdf_grid(const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = + Eigen::MatrixXd(0, 0)) const override; + + //! Evaluates the log-prior predictive distr. of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf + Eigen::VectorXd prior_pred_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, + 0)) const override; + + //! Evaluates the log-prior predictive distr. of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf + Eigen::VectorXd conditional_pred_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, + 0)) const override; + + //! Updates hyperparameter values given a vector of cluster states + void update_hypers(const std::vector + &states) override; + + //! Updates state values using the given (prior or posterior) hyperparameters + NNW::State draw(const NNW::Hyperparams ¶ms); + + //! Resets summary statistics for this cluster + void clear_summary_statistics() override; + + //! Returns the Protobuf ID associated to this class + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::NNW; + } + + //! Computes and return posterior hypers given data currently in this cluster + NNW::Hyperparams compute_posterior_hypers() const; + + //! Read and set state values from a given Protobuf message + void set_state_from_proto(const google::protobuf::Message &state_) override; + + //! Read and set hyperparameter values from a given Protobuf message + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + //! Writes current state to a Protobuf message and return a shared_ptr + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::ClusterState message by adding the appropriate type + std::shared_ptr get_state_proto() + const override; + + //! Writes current value of hyperparameters to a Protobuf message and + //! return a shared_ptr. + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::HierarchyHypers message by adding the appropriate type + std::shared_ptr get_hypers_proto() + const override; + + //! Returns whether the hierarchy models multivariate data or not + bool is_multivariate() const override { return true; } + + protected: + //! Evaluates the log-likelihood of data in a single point + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf + double like_lpdf(const Eigen::RowVectorXd &datum) const override; + + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf + double marg_lpdf(const NNW::Hyperparams ¶ms, + const Eigen::RowVectorXd &datum) const override; + + //! Updates cluster statistics when a datum is added or removed from it + //! @param datum Data point which is being added or removed + //! @param add Whether the datum is being added or removed + void update_summary_statistics(const Eigen::RowVectorXd &datum, + bool add) override; + + //! Writes prec and its utilities to the given state object by pointer + void write_prec_to_state(const Eigen::MatrixXd &prec_, NNW::State *out); + + //! Returns parameters for the predictive Student's t distribution + NNW::Hyperparams get_predictive_t_parameters( + const NNW::Hyperparams ¶ms) const; + + //! Initializes state parameters to appropriate values + void initialize_state() override; + + //! Initializes hierarchy hyperparameters to appropriate values + void initialize_hypers() override; + + //! Dimension of data space + unsigned int dim; + + //! Sum of data points currently belonging to the cluster + Eigen::VectorXd data_sum; + + //! Sum of squared data points currently belonging to the cluster + Eigen::MatrixXd data_sum_squares; +}; + +#endif // BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ diff --git a/src/hierarchies/.old/nnxig_hierarchy.cc b/src/hierarchies/.old/nnxig_hierarchy.cc new file mode 100644 index 000000000..f7bc62a16 --- /dev/null +++ b/src/hierarchies/.old/nnxig_hierarchy.cc @@ -0,0 +1,152 @@ +#include "nnxig_hierarchy.h" + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "hierarchy_prior.pb.h" +#include "ls_state.pb.h" +#include "src/utils/rng.h" + +double NNxIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { + return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); +} + +void NNxIGHierarchy::initialize_state() { + state.mean = hypers->mean; + state.var = hypers->scale / (hypers->shape + 1); +} + +void NNxIGHierarchy::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->mean = prior->fixed_values().mean(); + hypers->var = prior->fixed_values().var(); + hypers->shape = prior->fixed_values().shape(); + hypers->scale = prior->fixed_values().scale(); + + // Check validity + if (hypers->var <= 0) { + throw std::invalid_argument("Variance parameter must be > 0"); + } + if (hypers->shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers->scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void NNxIGHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, + bool add) { + if (add) { + data_sum += datum(0); + data_sum_squares += datum(0) * datum(0); + } else { + data_sum -= datum(0); + data_sum_squares -= datum(0) * datum(0); + } +} + +void NNxIGHierarchy::update_hypers( + const std::vector &states) { + auto &rng = bayesmix::Rng::Instance().get(); + if (prior->has_fixed_values()) { + return; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void NNxIGHierarchy::clear_summary_statistics() { + data_sum = 0; + data_sum_squares = 0; +} + +void NNxIGHierarchy::set_state_from_proto( + const google::protobuf::Message &state_) { + auto &statecast = downcast_state(state_); + state.mean = statecast.uni_ls_state().mean(); + state.var = statecast.uni_ls_state().var(); + set_card(statecast.cardinality()); +} + +void NNxIGHierarchy::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).nnxig_state(); + hypers->mean = hyperscast.mean(); + hypers->var = hyperscast.var(); + hypers->scale = hyperscast.scale(); + hypers->shape = hyperscast.shape(); +} + +std::shared_ptr +NNxIGHierarchy::get_state_proto() const { + bayesmix::UniLSState state_; + state_.set_mean(state.mean); + state_.set_var(state.var); + + auto out = std::make_shared(); + out->mutable_uni_ls_state()->CopyFrom(state_); + return out; +} + +std::shared_ptr +NNxIGHierarchy::get_hypers_proto() const { + bayesmix::NxIGDistribution hypers_; + hypers_.set_mean(hypers->mean); + hypers_.set_var(hypers->var); + hypers_.set_shape(hypers->shape); + hypers_.set_scale(hypers->scale); + + auto out = std::make_shared(); + out->mutable_nnxig_state()->CopyFrom(hypers_); + return out; +} + +void NNxIGHierarchy::sample_full_cond(bool update_params) { + if (this->card == 0) { + // No posterior update possible + sample_prior(); + } else { + NNxIG::Hyperparams params = + update_params ? compute_posterior_hypers() : posterior_hypers; + state = draw(params); + } +} + +NNxIG::State NNxIGHierarchy::draw(const NNxIG::Hyperparams ¶ms) { + auto &rng = bayesmix::Rng::Instance().get(); + NNxIG::State out; + out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); + out.mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng); + return out; +} + +NNxIG::Hyperparams NNxIGHierarchy::compute_posterior_hypers() const { + // Initialize relevant variables + if (card == 0) { // no update possible + return *hypers; + } + // Compute posterior hyperparameters + NNxIG::Hyperparams post_params; + double var_y = data_sum_squares - 2 * state.mean * data_sum + + card * state.mean * state.mean; + post_params.mean = (hypers->var * data_sum + state.var * hypers->mean) / + (card * hypers->var + state.var); + post_params.var = + (state.var * hypers->var) / (card * hypers->var + state.var); + post_params.shape = hypers->shape + 0.5 * card; + post_params.scale = hypers->scale + 0.5 * var_y; + return post_params; +} + +void NNxIGHierarchy::save_posterior_hypers() { + posterior_hypers = compute_posterior_hypers(); +} diff --git a/src/hierarchies/.old/nnxig_hierarchy.h b/src/hierarchies/.old/nnxig_hierarchy.h new file mode 100644 index 000000000..de5e878f6 --- /dev/null +++ b/src/hierarchies/.old/nnxig_hierarchy.h @@ -0,0 +1,120 @@ +#ifndef BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_hierarchy.h" +#include "hierarchy_id.pb.h" +#include "hierarchy_prior.pb.h" + +//! Non Conjugate Normal Normal-InverseGamma hierarchy for univariate data. + +//! This class represents a hierarchical model where data are distributed +//! according to a normal likelihood, the parameters of which have a +//! Normal-InverseGamma centering distribution. That is: +//! f(x_i|mu,sig) = N(mu,sig^2) +//! mu ~ N(mu0, sigma0) +//! sig^2 ~ IG(alpha0, beta0) +//! The state is composed of mean and variance. The state hyperparameters, +//! contained in the Hypers object, are (mu0, sigma0, alpha0, beta0), all +//! scalar values. Note that this hierarchy is non conjugate. + +namespace NNxIG { +//! Custom container for State values +struct State { + double mean, var; +}; + +//! Custom container for Hyperparameters values +struct Hyperparams { + double mean, var, shape, scale; +}; + +}; // namespace NNxIG + +class NNxIGHierarchy + : public BaseHierarchy { + public: + NNxIGHierarchy() = default; + ~NNxIGHierarchy() = default; + + //! Updates hyperparameter values given a vector of cluster states + void update_hypers(const std::vector + &states) override; + + //! Updates state values using the given (prior or posterior) hyperparameters + NNxIG::State draw(const NNxIG::Hyperparams ¶ms); + + //! Generates new state values from the centering posterior distribution + //! @param update_params Save posterior hypers after the computation? + void sample_full_cond(bool update_params = true) override; + + //! Saves posterior hyperparameters to the corresponding class member + void save_posterior_hypers(); + + //! Resets summary statistics for this cluster + void clear_summary_statistics() override; + + //! Returns the Protobuf ID associated to this class + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::NNxIG; + } + + //! Read and set state values from a given Protobuf message + void set_state_from_proto(const google::protobuf::Message &state_) override; + + //! Read and set hyperparameter values from a given Protobuf message + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + //! Writes current state to a Protobuf message and return a shared_ptr + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::ClusterState message by adding the appropriate type + std::shared_ptr get_state_proto() + const override; + + //! Writes current value of hyperparameters to a Protobuf message and + //! return a shared_ptr. + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::HierarchyHypers message by adding the appropriate type + std::shared_ptr get_hypers_proto() + const override; + + //! Computes and return posterior hypers given data currently in this cluster + NNxIG::Hyperparams compute_posterior_hypers() const; + + //! Returns whether the hierarchy models multivariate data or not + bool is_multivariate() const override { return false; } + + protected: + //! Evaluates the log-likelihood of data in a single point + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf + double like_lpdf(const Eigen::RowVectorXd &datum) const override; + + //! Updates cluster statistics when a datum is added or removed from it + //! @param datum Data point which is being added or removed + //! @param add Whether the datum is being added or removed + void update_summary_statistics(const Eigen::RowVectorXd &datum, + bool add) override; + + //! Initializes state parameters to appropriate values + void initialize_state() override; + + //! Initializes hierarchy hyperparameters to appropriate values + void initialize_hypers() override; + + //! Sum of data points currently belonging to the cluster + double data_sum = 0; + + //! Sum of squared data points currently belonging to the cluster + double data_sum_squares = 0; +}; + +#endif // BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ From 4caa7da2d46fd0916d0a22d85fed3070c8693163 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 19 Jan 2022 22:25:49 +0100 Subject: [PATCH 043/317] Added tests for new NNIG hierarchy pattern and for PriorModel sample --- test/CMakeLists.txt | 2 +- test/hierarchies.cc | 424 ++++++++++++++++++++++--------------------- test/prior_models.cc | 20 ++ 3 files changed, 234 insertions(+), 212 deletions(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 533330ddd..5a238990f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,7 +18,7 @@ FetchContent_MakeAvailable(googletest) add_executable(test_bayesmix $ # write_proto.cc # proto_utils.cc - # hierarchies.cc + hierarchies.cc # lpdf.cc # priors.cc # eigen_utils.cc diff --git a/test/hierarchies.cc b/test/hierarchies.cc index 2612eeb35..a9b98d79b 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -5,10 +5,10 @@ #include "algorithm_state.pb.h" #include "ls_state.pb.h" -#include "src/hierarchies/lin_reg_uni_hierarchy.h" +// #include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" -#include "src/hierarchies/nnw_hierarchy.h" -#include "src/hierarchies/nnxig_hierarchy.h" +// #include "src/hierarchies/nnw_hierarchy.h" +// #include "src/hierarchies/nnxig_hierarchy.h" #include "src/utils/proto_utils.h" #include "src/utils/rng.h" @@ -69,211 +69,213 @@ TEST(nnighierarchy, sample_given_data) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -TEST(nnwhierarchy, draw) { - auto hier = std::make_shared(); - bayesmix::NNWPrior prior; - Eigen::Vector2d mu0; - mu0 << 5.5, 5.5; - bayesmix::Vector mu0_proto; - bayesmix::to_proto(mu0, &mu0_proto); - double lambda0 = 0.2; - double nu0 = 5.0; - Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; - bayesmix::Matrix tau0_proto; - bayesmix::to_proto(tau0, &tau0_proto); - *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; - prior.mutable_fixed_values()->set_var_scaling(lambda0); - prior.mutable_fixed_values()->set_deg_free(nu0); - *prior.mutable_fixed_values()->mutable_scale() = tau0_proto; - hier->get_mutable_prior()->CopyFrom(prior); - hier->initialize(); - - auto hier2 = hier->clone(); - hier2->sample_prior(); - - bayesmix::AlgorithmState out; - bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); - bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); - hier->write_state_to_proto(clusval); - hier2->write_state_to_proto(clusval2); - - ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -} - -TEST(nnwhierarchy, sample_given_data) { - auto hier = std::make_shared(); - bayesmix::NNWPrior prior; - Eigen::Vector2d mu0; - mu0 << 5.5, 5.5; - bayesmix::Vector mu0_proto; - bayesmix::to_proto(mu0, &mu0_proto); - double lambda0 = 0.2; - double nu0 = 5.0; - Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; - bayesmix::Matrix tau0_proto; - bayesmix::to_proto(tau0, &tau0_proto); - *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; - prior.mutable_fixed_values()->set_var_scaling(lambda0); - prior.mutable_fixed_values()->set_deg_free(nu0); - *prior.mutable_fixed_values()->mutable_scale() = tau0_proto; - hier->get_mutable_prior()->CopyFrom(prior); - hier->initialize(); - - Eigen::RowVectorXd datum(2); - datum << 4.5, 4.5; - - auto hier2 = hier->clone(); - hier2->add_datum(0, datum, false); - hier2->sample_full_cond(); - - bayesmix::AlgorithmState out; - bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); - bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); - hier->write_state_to_proto(clusval); - hier2->write_state_to_proto(clusval2); - - ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -} - -TEST(lin_reg_uni_hierarchy, state_read_write) { - Eigen::Vector2d beta; - beta << 2, -1; - double sigma2 = 9; - - bayesmix::LinRegUniLSState ls; - bayesmix::to_proto(beta, ls.mutable_regression_coeffs()); - ls.set_var(sigma2); - - bayesmix::AlgorithmState::ClusterState state; - state.mutable_lin_reg_uni_ls_state()->CopyFrom(ls); - - LinRegUniHierarchy hier; - hier.set_state_from_proto(state); - - ASSERT_EQ(hier.get_state().regression_coeffs, beta); - ASSERT_EQ(hier.get_state().var, sigma2); - - bayesmix::AlgorithmState outt; - bayesmix::AlgorithmState::ClusterState* out = outt.add_cluster_states(); - hier.write_state_to_proto(out); - ASSERT_EQ(beta, bayesmix::to_eigen( - out->lin_reg_uni_ls_state().regression_coeffs())); - ASSERT_EQ(sigma2, out->lin_reg_uni_ls_state().var()); -} - -TEST(lin_reg_uni_hierarchy, misc) { - // Build data - int n = 5; - int dim = 2; - Eigen::Vector2d beta_true; - beta_true << 10.0, 10.0; - Eigen::MatrixXd cov = Eigen::MatrixXd::Random(n, dim); // each in U[-1,1] - double sigma2 = 1.0; - Eigen::VectorXd data(n); - auto& rng = bayesmix::Rng::Instance().get(); - for (int i = 0; i < n; i++) { - data(i) = stan::math::normal_rng(cov.row(i).dot(beta_true), sigma2, rng); - } - // Initialize objects - LinRegUniHierarchy hier; - bayesmix::LinRegUniPrior prior; - // Create prior parameters - Eigen::Vector2d beta0 = 0 * beta_true; - bayesmix::Vector beta0_proto; - bayesmix::to_proto(beta0, &beta0_proto); - auto Lambda0 = Eigen::Matrix2d::Identity(); - bayesmix::Matrix Lambda0_proto; - bayesmix::to_proto(Lambda0, &Lambda0_proto); - double a0 = 2.0; - double b0 = 1.0; - // Initialize hierarchy - *prior.mutable_fixed_values()->mutable_mean() = beta0_proto; - *prior.mutable_fixed_values()->mutable_var_scaling() = Lambda0_proto; - prior.mutable_fixed_values()->set_shape(a0); - prior.mutable_fixed_values()->set_scale(b0); - hier.get_mutable_prior()->CopyFrom(prior); - hier.initialize(); - // Extract hypers for reading test - bayesmix::AlgorithmState::HierarchyHypers out; - hier.write_hypers_to_proto(&out); - ASSERT_EQ(beta0, bayesmix::to_eigen(out.lin_reg_uni_state().mean())); - ASSERT_EQ(Lambda0, - bayesmix::to_eigen(out.lin_reg_uni_state().var_scaling())); - ASSERT_EQ(a0, out.lin_reg_uni_state().shape()); - ASSERT_EQ(b0, out.lin_reg_uni_state().scale()); - // Add data - for (int i = 0; i < n; i++) { - hier.add_datum(i, data.row(i), false, cov.row(i)); - } - // Check summary statistics - // for (int i = 0; i < dim; i++) { - // for (int j = 0; j < dim; j++) { - // ASSERT_DOUBLE_EQ(hier.get_covar_sum_squares()(i, j), - // (cov.transpose() * cov)(i, j)); - // } - // ASSERT_DOUBLE_EQ(hier.get_mixed_prod()(i), (cov.transpose() * data)(i)); - // } - // Compute and check posterior values - hier.sample_full_cond(); - auto state = hier.get_state(); - for (int i = 0; i < dim; i++) { - ASSERT_GT(state.regression_coeffs(i), beta0(i)); - } -} - -TEST(nnxighierarchy, draw) { - auto hier = std::make_shared(); - bayesmix::NNxIGPrior prior; - double mu0 = 5.0; - double var0 = 1.0; - double alpha0 = 2.0; - double beta0 = 2.0; - prior.mutable_fixed_values()->set_mean(mu0); - prior.mutable_fixed_values()->set_var(var0); - prior.mutable_fixed_values()->set_shape(alpha0); - prior.mutable_fixed_values()->set_scale(beta0); - hier->get_mutable_prior()->CopyFrom(prior); - hier->initialize(); - - auto hier2 = hier->clone(); - hier2->sample_prior(); - - bayesmix::AlgorithmState out; - bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); - bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); - hier->write_state_to_proto(clusval); - hier2->write_state_to_proto(clusval2); - - ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -} - -TEST(nnxighierarchy, sample_given_data) { - auto hier = std::make_shared(); - bayesmix::NNxIGPrior prior; - double mu0 = 5.0; - double var0 = 1.0; - double alpha0 = 2.0; - double beta0 = 2.0; - prior.mutable_fixed_values()->set_mean(mu0); - prior.mutable_fixed_values()->set_var(var0); - prior.mutable_fixed_values()->set_shape(alpha0); - prior.mutable_fixed_values()->set_scale(beta0); - hier->get_mutable_prior()->CopyFrom(prior); - - hier->initialize(); - - Eigen::VectorXd datum(1); - datum << 4.5; - - auto hier2 = hier->clone(); - hier2->add_datum(0, datum, false); - hier2->sample_full_cond(); - - bayesmix::AlgorithmState out; - bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); - bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); - hier->write_state_to_proto(clusval); - hier2->write_state_to_proto(clusval2); - - ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -} +// TEST(nnwhierarchy, draw) { +// auto hier = std::make_shared(); +// bayesmix::NNWPrior prior; +// Eigen::Vector2d mu0; +// mu0 << 5.5, 5.5; +// bayesmix::Vector mu0_proto; +// bayesmix::to_proto(mu0, &mu0_proto); +// double lambda0 = 0.2; +// double nu0 = 5.0; +// Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; +// bayesmix::Matrix tau0_proto; +// bayesmix::to_proto(tau0, &tau0_proto); +// *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; +// prior.mutable_fixed_values()->set_var_scaling(lambda0); +// prior.mutable_fixed_values()->set_deg_free(nu0); +// *prior.mutable_fixed_values()->mutable_scale() = tau0_proto; +// hier->get_mutable_prior()->CopyFrom(prior); +// hier->initialize(); + +// auto hier2 = hier->clone(); +// hier2->sample_prior(); + +// bayesmix::AlgorithmState out; +// bayesmix::AlgorithmState::ClusterState* clusval = +// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 +// = out.add_cluster_states(); hier->write_state_to_proto(clusval); +// hier2->write_state_to_proto(clusval2); + +// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +// } + +// TEST(nnwhierarchy, sample_given_data) { +// auto hier = std::make_shared(); +// bayesmix::NNWPrior prior; +// Eigen::Vector2d mu0; +// mu0 << 5.5, 5.5; +// bayesmix::Vector mu0_proto; +// bayesmix::to_proto(mu0, &mu0_proto); +// double lambda0 = 0.2; +// double nu0 = 5.0; +// Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; +// bayesmix::Matrix tau0_proto; +// bayesmix::to_proto(tau0, &tau0_proto); +// *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; +// prior.mutable_fixed_values()->set_var_scaling(lambda0); +// prior.mutable_fixed_values()->set_deg_free(nu0); +// *prior.mutable_fixed_values()->mutable_scale() = tau0_proto; +// hier->get_mutable_prior()->CopyFrom(prior); +// hier->initialize(); + +// Eigen::RowVectorXd datum(2); +// datum << 4.5, 4.5; + +// auto hier2 = hier->clone(); +// hier2->add_datum(0, datum, false); +// hier2->sample_full_cond(); + +// bayesmix::AlgorithmState out; +// bayesmix::AlgorithmState::ClusterState* clusval = +// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 +// = out.add_cluster_states(); hier->write_state_to_proto(clusval); +// hier2->write_state_to_proto(clusval2); + +// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +// } + +// TEST(lin_reg_uni_hierarchy, state_read_write) { +// Eigen::Vector2d beta; +// beta << 2, -1; +// double sigma2 = 9; + +// bayesmix::LinRegUniLSState ls; +// bayesmix::to_proto(beta, ls.mutable_regression_coeffs()); +// ls.set_var(sigma2); + +// bayesmix::AlgorithmState::ClusterState state; +// state.mutable_lin_reg_uni_ls_state()->CopyFrom(ls); + +// LinRegUniHierarchy hier; +// hier.set_state_from_proto(state); + +// ASSERT_EQ(hier.get_state().regression_coeffs, beta); +// ASSERT_EQ(hier.get_state().var, sigma2); + +// bayesmix::AlgorithmState outt; +// bayesmix::AlgorithmState::ClusterState* out = outt.add_cluster_states(); +// hier.write_state_to_proto(out); +// ASSERT_EQ(beta, bayesmix::to_eigen( +// out->lin_reg_uni_ls_state().regression_coeffs())); +// ASSERT_EQ(sigma2, out->lin_reg_uni_ls_state().var()); +// } + +// TEST(lin_reg_uni_hierarchy, misc) { +// // Build data +// int n = 5; +// int dim = 2; +// Eigen::Vector2d beta_true; +// beta_true << 10.0, 10.0; +// Eigen::MatrixXd cov = Eigen::MatrixXd::Random(n, dim); // each in U[-1,1] +// double sigma2 = 1.0; +// Eigen::VectorXd data(n); +// auto& rng = bayesmix::Rng::Instance().get(); +// for (int i = 0; i < n; i++) { +// data(i) = stan::math::normal_rng(cov.row(i).dot(beta_true), sigma2, +// rng); +// } +// // Initialize objects +// LinRegUniHierarchy hier; +// bayesmix::LinRegUniPrior prior; +// // Create prior parameters +// Eigen::Vector2d beta0 = 0 * beta_true; +// bayesmix::Vector beta0_proto; +// bayesmix::to_proto(beta0, &beta0_proto); +// auto Lambda0 = Eigen::Matrix2d::Identity(); +// bayesmix::Matrix Lambda0_proto; +// bayesmix::to_proto(Lambda0, &Lambda0_proto); +// double a0 = 2.0; +// double b0 = 1.0; +// // Initialize hierarchy +// *prior.mutable_fixed_values()->mutable_mean() = beta0_proto; +// *prior.mutable_fixed_values()->mutable_var_scaling() = Lambda0_proto; +// prior.mutable_fixed_values()->set_shape(a0); +// prior.mutable_fixed_values()->set_scale(b0); +// hier.get_mutable_prior()->CopyFrom(prior); +// hier.initialize(); +// // Extract hypers for reading test +// bayesmix::AlgorithmState::HierarchyHypers out; +// hier.write_hypers_to_proto(&out); +// ASSERT_EQ(beta0, bayesmix::to_eigen(out.lin_reg_uni_state().mean())); +// ASSERT_EQ(Lambda0, +// bayesmix::to_eigen(out.lin_reg_uni_state().var_scaling())); +// ASSERT_EQ(a0, out.lin_reg_uni_state().shape()); +// ASSERT_EQ(b0, out.lin_reg_uni_state().scale()); +// // Add data +// for (int i = 0; i < n; i++) { +// hier.add_datum(i, data.row(i), false, cov.row(i)); +// } +// // Check summary statistics +// // for (int i = 0; i < dim; i++) { +// // for (int j = 0; j < dim; j++) { +// // ASSERT_DOUBLE_EQ(hier.get_covar_sum_squares()(i, j), +// // (cov.transpose() * cov)(i, j)); +// // } +// // ASSERT_DOUBLE_EQ(hier.get_mixed_prod()(i), (cov.transpose() * +// data)(i)); +// // } +// // Compute and check posterior values +// hier.sample_full_cond(); +// auto state = hier.get_state(); +// for (int i = 0; i < dim; i++) { +// ASSERT_GT(state.regression_coeffs(i), beta0(i)); +// } +// } + +// TEST(nnxighierarchy, draw) { +// auto hier = std::make_shared(); +// bayesmix::NNxIGPrior prior; +// double mu0 = 5.0; +// double var0 = 1.0; +// double alpha0 = 2.0; +// double beta0 = 2.0; +// prior.mutable_fixed_values()->set_mean(mu0); +// prior.mutable_fixed_values()->set_var(var0); +// prior.mutable_fixed_values()->set_shape(alpha0); +// prior.mutable_fixed_values()->set_scale(beta0); +// hier->get_mutable_prior()->CopyFrom(prior); +// hier->initialize(); + +// auto hier2 = hier->clone(); +// hier2->sample_prior(); + +// bayesmix::AlgorithmState out; +// bayesmix::AlgorithmState::ClusterState* clusval = +// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 +// = out.add_cluster_states(); hier->write_state_to_proto(clusval); +// hier2->write_state_to_proto(clusval2); + +// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +// } + +// TEST(nnxighierarchy, sample_given_data) { +// auto hier = std::make_shared(); +// bayesmix::NNxIGPrior prior; +// double mu0 = 5.0; +// double var0 = 1.0; +// double alpha0 = 2.0; +// double beta0 = 2.0; +// prior.mutable_fixed_values()->set_mean(mu0); +// prior.mutable_fixed_values()->set_var(var0); +// prior.mutable_fixed_values()->set_shape(alpha0); +// prior.mutable_fixed_values()->set_scale(beta0); +// hier->get_mutable_prior()->CopyFrom(prior); + +// hier->initialize(); + +// Eigen::VectorXd datum(1); +// datum << 4.5; + +// auto hier2 = hier->clone(); +// hier2->add_datum(0, datum, false); +// hier2->sample_full_cond(); + +// bayesmix::AlgorithmState out; +// bayesmix::AlgorithmState::ClusterState* clusval = +// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 +// = out.add_cluster_states(); hier->write_state_to_proto(clusval); +// hier2->write_state_to_proto(clusval2); + +// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +// } diff --git a/test/prior_models.cc b/test/prior_models.cc index 0de9d9a8a..3c55ebd70 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -105,3 +105,23 @@ TEST(nig_prior_model, normal_mean_prior) { // Check ASSERT_GT(mean_out, mu00); } + +TEST(nig_prior_model, sample) { + // Instance + auto prior = std::make_shared(); + + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + hypers_proto.mutable_nnig_state()->set_mean(5.0); + hypers_proto.mutable_nnig_state()->set_var_scaling(0.1); + hypers_proto.mutable_nnig_state()->set_shape(4.0); + hypers_proto.mutable_nnig_state()->set_scale(3.0); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + auto state1 = prior->sample(false); + auto state2 = prior->sample(false); + + // Check if they coincides + ASSERT_TRUE(state1->DebugString() != state2->DebugString()); +} From e6f1ca435eec90028db7d30954fba15acb742708 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 19 Jan 2022 22:26:20 +0100 Subject: [PATCH 044/317] Add first trivial updater class --- src/hierarchies/updaters/CMakeLists.txt | 5 ++ src/hierarchies/updaters/nnig_updater.cc | 64 ++++++++++++++++++++++++ src/hierarchies/updaters/nnig_updater.h | 20 ++++++++ 3 files changed, 89 insertions(+) create mode 100644 src/hierarchies/updaters/CMakeLists.txt create mode 100644 src/hierarchies/updaters/nnig_updater.cc create mode 100644 src/hierarchies/updaters/nnig_updater.h diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt new file mode 100644 index 000000000..bb6da2924 --- /dev/null +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -0,0 +1,5 @@ +target_sources(bayesmix + PUBLIC + nnig_updater.h + nnig_updater.cc +) diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc new file mode 100644 index 000000000..cf4c0852f --- /dev/null +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -0,0 +1,64 @@ +#include "nnig_updater.h" + +std::shared_ptr NNIGUpdater::clone() const { + auto out = + std::make_shared(static_cast(*this)); + return out; +}; + +void NNIGUpdater::initialize(UniNormLikelihood &like, NIGPriorModel &prior) { + // PriorModel Initialization + prior.initialize(); + Hyperparams::NIG hypers = prior.get_hypers(); + prior.set_posterior_hypers(hypers); + + // State initialization + State::UniLS state; + state.mean = hypers.mean; + state.var = hypers.scale / (hypers.shape + 1); + + // Likelihood Initalization + like.set_state(state); + like.clear_data(); + like.clear_summary_statistics(); +}; + +void NNIGUpdater::compute_posterior_hypers(UniNormLikelihood &like, + NIGPriorModel &prior) { + // Getting required quantities from likelihood and prior + int card = like.get_card(); + double data_sum = like.get_data_sum(); + double data_sum_squares = like.get_data_sum_squares(); + auto hypers = std::make_shared(prior.get_hypers()); + + // No update possible + if (card == 0) { + prior.set_posterior_hypers(*hypers); + return; + } + + // Compute posterior hyperparameters + Hyperparams::NIG post_params; + double y_bar = data_sum / (1.0 * card); // sample mean + double ss = data_sum_squares - card * y_bar * y_bar; + post_params.mean = (hypers->var_scaling * hypers->mean + data_sum) / + (hypers->var_scaling + card); + post_params.var_scaling = hypers->var_scaling + card; + post_params.shape = hypers->shape + 0.5 * card; + post_params.scale = hypers->scale + 0.5 * ss + + 0.5 * hypers->var_scaling * card * + (y_bar - hypers->mean) * (y_bar - hypers->mean) / + (card + hypers->var_scaling); + + prior.set_posterior_hypers(post_params); + return; +}; + +void NNIGUpdater::draw(UniNormLikelihood &like, NIGPriorModel &prior) { + if (like.get_card() == 0) { + like.set_state_from_proto(*prior.sample(true)); + } else { + compute_posterior_hypers(like, prior); + like.set_state_from_proto(*prior.sample(true)); + } +}; diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h new file mode 100644 index 000000000..074f82f3e --- /dev/null +++ b/src/hierarchies/updaters/nnig_updater.h @@ -0,0 +1,20 @@ +#ifndef BAYESMIX_HIERARCHIES_NNIG_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_NNIG_UPDATER_H_ + +#include "src/hierarchies/likelihoods/states.h" +#include "src/hierarchies/likelihoods/uni_norm_likelihood.h" +#include "src/hierarchies/priors/hyperparams.h" +#include "src/hierarchies/priors/nig_prior_model.h" + +class NNIGUpdater { + public: + NNIGUpdater() = default; + ~NNIGUpdater() = default; + + std::shared_ptr clone() const; + void draw(UniNormLikelihood& like, NIGPriorModel& prior); + void initialize(UniNormLikelihood& like, NIGPriorModel& prior); + void compute_posterior_hypers(UniNormLikelihood& like, NIGPriorModel& prior); +}; + +#endif // BAYESMIX_HIERARCHIES_NNIG_UPDATERS_H_ From 121b6328bd813b9a29435d9fe7ee210d7b00f526 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 19 Jan 2022 22:28:24 +0100 Subject: [PATCH 045/317] Revert previous changes --- src/hierarchies/likelihoods/states.h | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/hierarchies/likelihoods/states.h b/src/hierarchies/likelihoods/states.h index e9d7b7327..70dc7f032 100644 --- a/src/hierarchies/likelihoods/states.h +++ b/src/hierarchies/likelihoods/states.h @@ -5,25 +5,17 @@ namespace State { -class Base { - protected: - Base() = default; - - public: - virtual ~Base() = default; -}; - -struct UniLS : public Base { +struct UniLS { double mean, var; }; -struct MultiLS : public Base { +struct MultiLS { Eigen::VectorXd mean; Eigen::MatrixXd prec, prec_chol; double prec_logdet; }; -struct UniLinReg : public Base { +struct UniLinReg { Eigen::VectorXd regression_coeffs; double var; }; From f51cdea311829cc390d37a65ecdb98f338ce0bd6 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 19 Jan 2022 22:29:27 +0100 Subject: [PATCH 046/317] Removed not yet ported hierarchies --- src/hierarchies/CMakeLists.txt | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index b354649d0..234a44f6d 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -2,16 +2,17 @@ target_sources(bayesmix PUBLIC abstract_hierarchy.h base_hierarchy.h - conjugate_hierarchy.h - lin_reg_uni_hierarchy.h - lin_reg_uni_hierarchy.cc + # conjugate_hierarchy.h + # lin_reg_uni_hierarchy.h + # lin_reg_uni_hierarchy.cc nnig_hierarchy.h - nnig_hierarchy.cc - nnw_hierarchy.h - nnw_hierarchy.cc - nnxig_hierarchy.h - nnxig_hierarchy.cc + # nnig_hierarchy.cc + # nnw_hierarchy.h + # nnw_hierarchy.cc + # nnxig_hierarchy.h + # nnxig_hierarchy.cc ) add_subdirectory(likelihoods) add_subdirectory(priors) +add_subdirectory(updaters) From 0bd86f017ee33b191226b4c407d1a541bb002126 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 19 Jan 2022 22:34:29 +0100 Subject: [PATCH 047/317] NNIG Hierarchy partially ported to composition pattern (ONGOING) --- src/hierarchies/abstract_hierarchy.h | 4 +- src/hierarchies/base_hierarchy.h | 610 ++++++++++++++++----------- src/hierarchies/nnig_hierarchy.h | 119 +----- 3 files changed, 373 insertions(+), 360 deletions(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index ea5a47cc1..9b54c1fa7 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -198,10 +198,10 @@ class AbstractHierarchy { virtual bool is_multivariate() const = 0; //! Returns whether the hierarchy depends on covariate values or not - virtual bool is_dependent() const { return false; } + virtual bool is_dependent() const = 0; //! Returns whether the hierarchy represents a conjugate model or not - virtual bool is_conjugate() const { return false; } + virtual bool is_conjugate() const = 0; protected: //! Evaluates the log-likelihood of data in a single point diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index e97c92080..4e4407795 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -30,292 +30,392 @@ //! @tparam Hyperparams Class name of the container for hyperprior parameters //! @tparam Prior Class name of the container for prior parameters -template +template class BaseHierarchy : public AbstractHierarchy { + protected: + std::shared_ptr like = std::make_shared(); + std::shared_ptr prior = std::make_shared(); + std::shared_ptr updater = std::make_shared(); + public: BaseHierarchy() = default; ~BaseHierarchy() = default; - //! Returns an independent, data-less copy of this object - virtual std::shared_ptr clone() const override { + void set_likelihood(std::shared_ptr like_) { like = like_; }; + void set_prior(std::shared_ptr prior_) { prior = prior_; }; + void set_updater(std::shared_ptr updater_) { updater = updater_; }; + + std::shared_ptr clone() const override { + // Create copy of the hierarchy auto out = std::make_shared(static_cast(*this)); - out->clear_data(); - out->clear_summary_statistics(); + // Cloning each component class + out->set_likelihood(std::static_pointer_cast(like->clone())); + out->set_prior(std::static_pointer_cast(prior->clone())); + out->set_updater(std::static_pointer_cast(updater->clone())); return out; - } + }; - //! Evaluates the log-likelihood of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - virtual Eigen::VectorXd like_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; + Eigen::VectorXd like_lpdf_grid(const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = + Eigen::MatrixXd(0, 0)) const override { + return like->lpdf_grid(data, covariates); + }; - //! Generates new state values from the centering prior distribution void sample_prior() override { - state = static_cast(this)->draw(*hypers); - } + like->set_state_from_proto(*prior->sample(false)); + }; - //! Overloaded version of sample_full_cond(bool), mainly used for debugging - virtual void sample_full_cond( + void sample_full_cond(bool update_params = false) override { + updater->draw(*like, *prior); + }; + + // DA IMPLEMENTARE !!! + void sample_full_cond( const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override; + const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override { + return; + }; - //! Returns the current cardinality of the cluster - int get_card() const override { return card; } + void update_hypers(const std::vector + &states) override { + prior->update_hypers(states); + }; - //! Returns the logarithm of the current cardinality of the cluster - double get_log_card() const override { return log_card; } + int get_card() const override { return like->get_card(); }; - //! Returns the indexes of data points belonging to this cluster - std::set get_data_idx() const override { return cluster_data_idx; } + double get_log_card() const override { return like->get_log_card(); }; - //! Returns a pointer to the Protobuf message of the prior of this cluster - virtual google::protobuf::Message *get_mutable_prior() override { - if (prior == nullptr) { - create_empty_prior(); - } - return prior.get(); - } + std::set get_data_idx() const override { return like->get_data_idx(); }; - //! Writes current state to a Protobuf message by pointer - void write_state_to_proto(google::protobuf::Message *out) const override; + google::protobuf::Message *get_mutable_prior() { + return prior->get_mutable_prior(); + }; - //! Writes current values of the hyperparameters to a Protobuf message by - //! pointer - void write_hypers_to_proto(google::protobuf::Message *out) const override; + void write_state_to_proto(google::protobuf::Message *out) const override { + like->write_state_to_proto(out); + }; - //! Returns the struct of the current state - State get_state() const { return state; } + void write_hypers_to_proto(google::protobuf::Message *out) const override { + prior->write_hypers_to_proto(out); + }; - //! Returns the struct of the current prior hyperparameters - Hyperparams get_hypers() const { return *hypers; } + void set_state_from_proto(const google::protobuf::Message &state_) override { + like->set_state_from_proto(state_); + }; - //! Returns the struct of the current posterior hyperparameters - Hyperparams get_posterior_hypers() const { return posterior_hypers; } + void set_hypers_from_proto( + const google::protobuf::Message &state_) override { + prior->set_hypers_from_proto(state_); + }; - //! Adds a datum and its index to the hierarchy + // DA SISTEMARE void add_datum( const int id, const Eigen::RowVectorXd &datum, const bool update_params = false, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override { + // gestire update_params !! + like->add_datum(id, datum, covariate); + }; - //! Removes a datum and its index from the hierarchy + // DA SISTEMARE void remove_datum( const int id, const Eigen::RowVectorXd &datum, const bool update_params = false, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; - - //! Main function that initializes members to appropriate values - void initialize() override { - hypers = std::make_shared(); - check_prior_is_set(); - initialize_hypers(); - initialize_state(); - posterior_hypers = *hypers; - clear_data(); - clear_summary_statistics(); - } + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override { + // gestire update_params !! + like->remove_datum(id, datum, covariate); + }; - protected: - //! Raises an error if the prior pointer is not initialized - void check_prior_is_set() const { - if (prior == nullptr) { - throw std::invalid_argument("Hierarchy prior was not provided"); - } - } - - //! Re-initializes the prior of the hierarchy to a newly created object - void create_empty_prior() { prior.reset(new Prior); } - - //! Sets the cardinality of the cluster - void set_card(const int card_) { - card = card_; - log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); - } - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - virtual std::shared_ptr - get_state_proto() const = 0; - - //! Initializes state parameters to appropriate values - virtual void initialize_state() = 0; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - virtual std::shared_ptr - get_hypers_proto() const = 0; - - //! Initializes hierarchy hyperparameters to appropriate values - virtual void initialize_hypers() = 0; - - //! Resets cardinality and indexes of data in this cluster - void clear_data() { - set_card(0); - cluster_data_idx = std::set(); - } - - virtual void clear_summary_statistics() = 0; - - //! Down-casts the given generic proto message to a ClusterState proto - bayesmix::AlgorithmState::ClusterState *downcast_state( - google::protobuf::Message *state_) const { - return google::protobuf::internal::down_cast< - bayesmix::AlgorithmState::ClusterState *>(state_); - } - - //! Down-casts the given generic proto message to a ClusterState proto - const bayesmix::AlgorithmState::ClusterState &downcast_state( - const google::protobuf::Message &state_) const { - return google::protobuf::internal::down_cast< - const bayesmix::AlgorithmState::ClusterState &>(state_); - } - - //! Down-casts the given generic proto message to a HierarchyHypers proto - bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( - google::protobuf::Message *state_) const { - return google::protobuf::internal::down_cast< - bayesmix::AlgorithmState::HierarchyHypers *>(state_); - } - - //! Down-casts the given generic proto message to a HierarchyHypers proto - const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( - const google::protobuf::Message &state_) const { - return google::protobuf::internal::down_cast< - const bayesmix::AlgorithmState::HierarchyHypers &>(state_); - } - - //! Container for state values - State state; - - //! Container for prior hyperparameters values - std::shared_ptr hypers; - - //! Container for posterior hyperparameters values - Hyperparams posterior_hypers; - - //! Pointer to a Protobuf prior object for this class - std::shared_ptr prior; - - //! Set of indexes of data points belonging to this cluster - std::set cluster_data_idx; - - //! Current cardinality of this cluster - int card = 0; - - //! Logarithm of current cardinality of this cluster - double log_card = stan::math::NEGATIVE_INFTY; + void initialize() override { updater->initialize(*like, *prior); }; + + bool is_multivariate() const override { return like->is_multivariate(); }; + + bool is_dependent() const override { return like->is_dependent(); }; }; -template -void BaseHierarchy::add_datum( - const int id, const Eigen::RowVectorXd &datum, - const bool update_params /*= false*/, - const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) { - assert(cluster_data_idx.find(id) == cluster_data_idx.end()); - card += 1; - log_card = std::log(card); - static_cast(this)->update_ss(datum, covariate, true); - cluster_data_idx.insert(id); - if (update_params) { - static_cast(this)->save_posterior_hypers(); - } -} - -template -void BaseHierarchy::remove_datum( - const int id, const Eigen::RowVectorXd &datum, - const bool update_params /*= false*/, - const Eigen::RowVectorXd &covariate /* = Eigen::RowVectorXd(0)*/) { - static_cast(this)->update_ss(datum, covariate, false); - set_card(card - 1); - auto it = cluster_data_idx.find(id); - assert(it != cluster_data_idx.end()); - cluster_data_idx.erase(it); - if (update_params) { - static_cast(this)->save_posterior_hypers(); - } -} - -template -void BaseHierarchy::write_state_to_proto( - google::protobuf::Message *out) const { - std::shared_ptr state_ = - get_state_proto(); - auto *out_cast = downcast_state(out); - out_cast->CopyFrom(*state_.get()); - out_cast->set_cardinality(card); -} - -template -void BaseHierarchy::write_hypers_to_proto( - google::protobuf::Message *out) const { - std::shared_ptr hypers_ = - get_hypers_proto(); - auto *out_cast = downcast_hypers(out); - out_cast->CopyFrom(*hypers_.get()); -} - -template -Eigen::VectorXd -BaseHierarchy::like_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { - Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->get_like_lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->get_like_lpdf( - data.row(i), covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->get_like_lpdf( - data.row(i), covariates.row(i)); - } - } - return lpdf; -} - -template -void BaseHierarchy::sample_full_cond( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) { - clear_data(); - clear_summary_statistics(); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - covariates.row(i)); - } - } - static_cast(this)->sample_full_cond(true); -} +// //! Returns an independent, data-less copy of this object +// virtual std::shared_ptr clone() const override { +// auto out = std::make_shared(static_cast(*this)); out->clear_data(); out->clear_summary_statistics(); return +// out; +// } + +// //! Evaluates the log-likelihood of data in a grid of points +// //! @param data Grid of points (by row) which are to be evaluated +// //! @param covariates (Optional) covariate vectors associated to data +// //! @return The evaluation of the lpdf +// virtual Eigen::VectorXd like_lpdf_grid( +// const Eigen::MatrixXd &data, +// const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, +// 0)) const +// override; + +// //! Generates new state values from the centering prior distribution +// void sample_prior() override { +// state = static_cast(this)->draw(*hypers); +// } + +// //! Overloaded version of sample_full_cond(bool), mainly used for +// debugging virtual void sample_full_cond( +// const Eigen::MatrixXd &data, +// const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override; + +// //! Returns the current cardinality of the cluster +// int get_card() const override { return card; } + +// //! Returns the logarithm of the current cardinality of the cluster +// double get_log_card() const override { return log_card; } + +// //! Returns the indexes of data points belonging to this cluster +// std::set get_data_idx() const override { return cluster_data_idx; } + +// //! Returns a pointer to the Protobuf message of the prior of this cluster +// virtual google::protobuf::Message *get_mutable_prior() override { +// if (prior == nullptr) { +// create_empty_prior(); +// } +// return prior.get(); +// } + +// //! Writes current state to a Protobuf message by pointer +// void write_state_to_proto(google::protobuf::Message *out) const override; + +// //! Writes current values of the hyperparameters to a Protobuf message by +// //! pointer +// void write_hypers_to_proto(google::protobuf::Message *out) const override; + +// //! Returns the struct of the current state +// State get_state() const { return state; } + +// //! Returns the struct of the current prior hyperparameters +// Hyperparams get_hypers() const { return *hypers; } + +// //! Returns the struct of the current posterior hyperparameters +// Hyperparams get_posterior_hypers() const { return posterior_hypers; } + +// //! Adds a datum and its index to the hierarchy +// void add_datum( +// const int id, const Eigen::RowVectorXd &datum, +// const bool update_params = false, +// const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + +// //! Removes a datum and its index from the hierarchy +// void remove_datum( +// const int id, const Eigen::RowVectorXd &datum, +// const bool update_params = false, +// const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + +// //! Main function that initializes members to appropriate values +// void initialize() override { +// hypers = std::make_shared(); +// check_prior_is_set(); +// initialize_hypers(); +// initialize_state(); +// posterior_hypers = *hypers; +// clear_data(); +// clear_summary_statistics(); +// } + +// protected: +// //! Raises an error if the prior pointer is not initialized +// void check_prior_is_set() const { +// if (prior == nullptr) { +// throw std::invalid_argument("Hierarchy prior was not provided"); +// } +// } + +// //! Re-initializes the prior of the hierarchy to a newly created object +// void create_empty_prior() { prior.reset(new Prior); } + +// //! Sets the cardinality of the cluster +// void set_card(const int card_) { +// card = card_; +// log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); +// } + +// //! Writes current state to a Protobuf message and return a shared_ptr +// //! New hierarchies have to first modify the field 'oneof val' in the +// //! AlgoritmState::ClusterState message by adding the appropriate type +// virtual std::shared_ptr +// get_state_proto() const = 0; + +// //! Initializes state parameters to appropriate values +// virtual void initialize_state() = 0; + +// //! Writes current value of hyperparameters to a Protobuf message and +// //! return a shared_ptr. +// //! New hierarchies have to first modify the field 'oneof val' in the +// //! AlgoritmState::HierarchyHypers message by adding the appropriate type +// virtual std::shared_ptr +// get_hypers_proto() const = 0; + +// //! Initializes hierarchy hyperparameters to appropriate values +// virtual void initialize_hypers() = 0; + +// //! Resets cardinality and indexes of data in this cluster +// void clear_data() { +// set_card(0); +// cluster_data_idx = std::set(); +// } + +// virtual void clear_summary_statistics() = 0; + +// //! Down-casts the given generic proto message to a ClusterState proto +// bayesmix::AlgorithmState::ClusterState *downcast_state( +// google::protobuf::Message *state_) const { +// return google::protobuf::internal::down_cast< +// bayesmix::AlgorithmState::ClusterState *>(state_); +// } + +// //! Down-casts the given generic proto message to a ClusterState proto +// const bayesmix::AlgorithmState::ClusterState &downcast_state( +// const google::protobuf::Message &state_) const { +// return google::protobuf::internal::down_cast< +// const bayesmix::AlgorithmState::ClusterState &>(state_); +// } + +// //! Down-casts the given generic proto message to a HierarchyHypers proto +// bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( +// google::protobuf::Message *state_) const { +// return google::protobuf::internal::down_cast< +// bayesmix::AlgorithmState::HierarchyHypers *>(state_); +// } + +// //! Down-casts the given generic proto message to a HierarchyHypers proto +// const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( +// const google::protobuf::Message &state_) const { +// return google::protobuf::internal::down_cast< +// const bayesmix::AlgorithmState::HierarchyHypers &>(state_); +// } + +// //! Container for state values +// State state; + +// //! Container for prior hyperparameters values +// std::shared_ptr hypers; + +// //! Container for posterior hyperparameters values +// Hyperparams posterior_hypers; + +// //! Pointer to a Protobuf prior object for this class +// std::shared_ptr prior; + +// //! Set of indexes of data points belonging to this cluster +// std::set cluster_data_idx; + +// //! Current cardinality of this cluster +// int card = 0; + +// //! Logarithm of current cardinality of this cluster +// double log_card = stan::math::NEGATIVE_INFTY; +// }; + +// template void BaseHierarchy::add_datum( +// const int id, const Eigen::RowVectorXd &datum, +// const bool update_params /*= false*/, +// const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) { +// assert(cluster_data_idx.find(id) == cluster_data_idx.end()); +// card += 1; +// log_card = std::log(card); +// static_cast(this)->update_ss(datum, covariate, true); +// cluster_data_idx.insert(id); +// if (update_params) { +// static_cast(this)->save_posterior_hypers(); +// } +// } + +// template void BaseHierarchy::remove_datum( +// const int id, const Eigen::RowVectorXd &datum, +// const bool update_params /*= false*/, +// const Eigen::RowVectorXd &covariate /* = Eigen::RowVectorXd(0)*/) { +// static_cast(this)->update_ss(datum, covariate, false); +// set_card(card - 1); +// auto it = cluster_data_idx.find(id); +// assert(it != cluster_data_idx.end()); +// cluster_data_idx.erase(it); +// if (update_params) { +// static_cast(this)->save_posterior_hypers(); +// } +// } + +// template void BaseHierarchy::write_state_to_proto( +// google::protobuf::Message *out) const { +// std::shared_ptr state_ = +// get_state_proto(); +// auto *out_cast = downcast_state(out); +// out_cast->CopyFrom(*state_.get()); +// out_cast->set_cardinality(card); +// } + +// template void BaseHierarchy::write_hypers_to_proto( +// google::protobuf::Message *out) const { +// std::shared_ptr hypers_ = +// get_hypers_proto(); +// auto *out_cast = downcast_hypers(out); +// out_cast->CopyFrom(*hypers_.get()); +// } + +// template Eigen::VectorXd BaseHierarchy::like_lpdf_grid( +// const Eigen::MatrixXd &data, +// const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { +// Eigen::VectorXd lpdf(data.rows()); +// if (covariates.cols() == 0) { +// // Pass null value as covariate +// for (int i = 0; i < data.rows(); i++) { +// lpdf(i) = static_cast(this)->get_like_lpdf( +// data.row(i), Eigen::RowVectorXd(0)); +// } +// } else if (covariates.rows() == 1) { +// // Use unique covariate +// for (int i = 0; i < data.rows(); i++) { +// lpdf(i) = static_cast(this)->get_like_lpdf( +// data.row(i), covariates.row(0)); +// } +// } else { +// // Use different covariates +// for (int i = 0; i < data.rows(); i++) { +// lpdf(i) = static_cast(this)->get_like_lpdf( +// data.row(i), covariates.row(i)); +// } +// } +// return lpdf; +// } + +// template void BaseHierarchy::sample_full_cond( +// const Eigen::MatrixXd &data, +// const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) { +// clear_data(); +// clear_summary_statistics(); +// if (covariates.cols() == 0) { +// // Pass null value as covariate +// for (int i = 0; i < data.rows(); i++) { +// static_cast(this)->add_datum(i, data.row(i), false, +// Eigen::RowVectorXd(0)); +// } +// } else if (covariates.rows() == 1) { +// // Use unique covariate +// for (int i = 0; i < data.rows(); i++) { +// static_cast(this)->add_datum(i, data.row(i), false, +// covariates.row(0)); +// } +// } else { +// // Use different covariates +// for (int i = 0; i < data.rows(); i++) { +// static_cast(this)->add_datum(i, data.row(i), false, +// covariates.row(i)); +// } +// } +// static_cast(this)->sample_full_cond(true); +// } #endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index 7911691e9..1271159fe 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -1,122 +1,35 @@ #ifndef BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ -#include +// #include -#include -#include -#include +// #include +// #include +// #include -#include "algorithm_state.pb.h" -#include "conjugate_hierarchy.h" +// #include "algorithm_state.pb.h" +// #include "conjugate_hierarchy.h" #include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" +// #include "hierarchy_prior.pb.h" -//! Conjugate Normal Normal-InverseGamma hierarchy for univariate data. +#include "base_hierarchy.h" +#include "likelihoods/uni_norm_likelihood.h" +#include "priors/nig_prior_model.h" +#include "updaters/nnig_updater.h" -//! This class represents a hierarchical model where data are distributed -//! according to a normal likelihood, the parameters of which have a -//! Normal-InverseGamma centering distribution. That is: -//! f(x_i|mu,sig) = N(mu,sig^2) -//! (mu,sig^2) ~ N-IG(mu0, lambda0, alpha0, beta0) -//! The state is composed of mean and variance. The state hyperparameters, -//! contained in the Hypers object, are (mu_0, lambda0, alpha0, beta0), all -//! scalar values. Note that this hierarchy is conjugate, thus the marginal -//! distribution is available in closed form. For more information, please -//! refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and -//! `ConjugateHierarchy`. - -namespace NNIG { -//! Custom container for State values -struct State { - double mean, var; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - double mean, var_scaling, shape, scale; -}; - -}; // namespace NNIG - -class NNIGHierarchy - : public ConjugateHierarchy { +class NNIGHierarchy : public BaseHierarchy { public: NNIGHierarchy() = default; ~NNIGHierarchy() = default; - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - NNIG::State draw(const NNIG::Hyperparams ¶ms); - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNIG; } - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Computes and return posterior hypers given data currently in this cluster - NNIG::Hyperparams compute_posterior_hypers() const; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return false; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum) const override; - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double marg_lpdf(const NNIG::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) override; - - //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Sum of data points currently belonging to the cluster - double data_sum = 0; + bool is_conjugate() const override { return true; } - //! Sum of squared data points currently belonging to the cluster - double data_sum_squares = 0; + // MANCANO LE PREDICTIVE LPDFS (DOVE LE METTIAMO)? }; -#endif // BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ +#endif From 024e274588dc12094ce4879215b97c91e365c7f7 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 19 Jan 2022 22:35:02 +0100 Subject: [PATCH 048/317] Improved API and bug fixes --- src/hierarchies/likelihoods/abstract_likelihood.h | 3 +++ src/hierarchies/likelihoods/base_likelihood.h | 12 +++++++----- src/hierarchies/likelihoods/uni_norm_likelihood.h | 2 ++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index eef452965..8a7a8841c 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -30,6 +30,9 @@ class AbstractLikelihood { } } + // AGGIUNGERE CLUST_LPDF (CHE VALUTA LA LIKELIHOOD CONGIUNTA SU TUTTO IL + // CLUSTER) + virtual Eigen::VectorXd lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const = 0; diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 9f35b82cd..90a32b584 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -39,6 +39,8 @@ class BaseLikelihood : public AbstractLikelihood { State get_state() const { return state; } + void set_state(const State &_state) { state = _state; }; + void add_datum( const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; @@ -47,17 +49,17 @@ class BaseLikelihood : public AbstractLikelihood { const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + void clear_data() { + set_card(0); + cluster_data_idx = std::set(); + } + protected: void set_card(const int card_) { card = card_; log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); } - void clear_data() { - set_card(0); - cluster_data_idx = std::set(); - } - bayesmix::AlgorithmState::ClusterState *downcast_state( google::protobuf::Message *state_) const { return google::protobuf::internal::down_cast< diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 5607d8b62..e36f9d908 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -21,6 +21,8 @@ class UniNormLikelihood bool is_dependent() const override { return false; }; void set_state_from_proto(const google::protobuf::Message &state_) override; void clear_summary_statistics() override; + double get_data_sum() const { return data_sum; }; + double get_data_sum_squares() const { return data_sum_squares; }; protected: std::shared_ptr get_state_proto() From af886f1f2c299e3b97a6512c43ead5842c99aeb2 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 19 Jan 2022 22:37:58 +0100 Subject: [PATCH 049/317] Improved API and bug fixes --- src/hierarchies/priors/abstract_prior_model.h | 6 ++- src/hierarchies/priors/base_prior_model.h | 11 +++++ src/hierarchies/priors/nig_prior_model.cc | 49 ++++++++----------- src/hierarchies/priors/nig_prior_model.h | 5 +- 4 files changed, 39 insertions(+), 32 deletions(-) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 1d44b6a9c..f3d4e2702 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -8,6 +8,7 @@ #include #include "algorithm_state.pb.h" +#include "src/hierarchies/likelihoods/states.h" #include "src/utils/rng.h" class AbstractPriorModel { @@ -17,10 +18,11 @@ class AbstractPriorModel { // IMPLEMENTED in BasePriorModel virtual std::shared_ptr clone() const = 0; - virtual double lpdf() = 0; + virtual double lpdf(const google::protobuf::Message &state_) = 0; // Da pensare, come restituisco lo stato? magari un pointer? Oppure delego - // all'updater?? virtual void sample() = 0; + virtual std::shared_ptr sample( + bool use_post_hypers = false) = 0; virtual void update_hypers( const std::vector &states) = 0; diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index a3a46728b..a1c0aa980 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -27,6 +27,10 @@ class BasePriorModel : public AbstractPriorModel { HyperParams get_hypers() const { return *hypers; } + void set_posterior_hypers(const HyperParams &_post_hypers) { + post_hypers = std::make_shared(_post_hypers); + }; + void write_hypers_to_proto(google::protobuf::Message *out) const override; void initialize(); @@ -48,7 +52,14 @@ class BasePriorModel : public AbstractPriorModel { const bayesmix::AlgorithmState::HierarchyHypers &>(state_); } + const bayesmix::AlgorithmState::ClusterState &downcast_state( + const google::protobuf::Message &state_) const { + return google::protobuf::internal::down_cast< + const bayesmix::AlgorithmState::ClusterState &>(state_); + } + std::shared_ptr hypers = std::make_shared(); + std::shared_ptr post_hypers = std::make_shared(); std::shared_ptr prior; }; diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index 875d247dd..6f9b6cbb5 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -75,38 +75,29 @@ void NIGPriorModel::initialize_hypers() { } } -double NIGPriorModel::lpdf() { - if (prior->has_fixed_values()) { - return 0; - } else if (prior->has_normal_mean_prior()) { - double mu = prior->normal_mean_prior().mean_prior().mean(); - double var = prior->normal_mean_prior().mean_prior().var(); - return stan::math::normal_lpdf(hypers->mean, mu, sqrt(var)); - } else if (prior->has_ngg_prior()) { - // Set variables - double mu, var, shape, rate; - double target = 0; - - // Gaussian distribution on the mean - mu = prior->ngg_prior().mean_prior().mean(); - var = prior->ngg_prior().mean_prior().var(); - target += stan::math::normal_lpdf(hypers->mean, mu, sqrt(var)); +double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { + auto &state = downcast_state(state_).uni_ls_state(); + double target = + stan::math::normal_lpdf(state.mean(), hypers->mean, + sqrt(state.var() / hypers->var_scaling)) + + stan::math::inv_gamma_lpdf(state.var(), hypers->shape, hypers->scale); + return target; +} - // Gamma distribution on var_scaling - shape = prior->ngg_prior().var_scaling_prior().shape(); - rate = prior->ngg_prior().var_scaling_prior().rate(); - target += stan::math::gamma_lpdf(hypers->var_scaling, shape, rate); +std::shared_ptr NIGPriorModel::sample( + bool use_post_hypers) { + auto &rng = bayesmix::Rng::Instance().get(); + Hyperparams::NIG params = use_post_hypers ? *post_hypers : *hypers; - // Gamma distribution on scale - shape = prior->ngg_prior().scale_prior().shape(); - rate = prior->ngg_prior().scale_prior().rate(); - target += stan::math::gamma_lpdf(hypers->var_scaling, shape, rate); + double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); + double mean = + stan::math::normal_rng(params.mean, sqrt(var / params.var_scaling), rng); - return target; - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} + bayesmix::AlgorithmState::ClusterState state; + state.mutable_uni_ls_state()->set_mean(mean); + state.mutable_uni_ls_state()->set_var(var); + return std::make_shared(state); +}; void NIGPriorModel::update_hypers( const std::vector &states) { diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index ca359980b..bf92a6825 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -20,7 +20,10 @@ class NIGPriorModel : public BasePriorModel sample( + bool use_post_hypers = false) override; void update_hypers(const std::vector &states) override; From 31a0c36c61ce4b4e7dd045a275ee318a9519854d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 19 Jan 2022 22:38:23 +0100 Subject: [PATCH 050/317] Moved old hierarchy pattern to .old --- src/hierarchies/conjugate_hierarchy.h | 206 ------------- src/hierarchies/lin_reg_uni_hierarchy.cc | 166 ---------- src/hierarchies/lin_reg_uni_hierarchy.h | 144 --------- src/hierarchies/nnig_hierarchy.cc | 267 ---------------- src/hierarchies/nnw_hierarchy.cc | 373 ----------------------- src/hierarchies/nnw_hierarchy.h | 168 ---------- src/hierarchies/nnxig_hierarchy.cc | 152 --------- src/hierarchies/nnxig_hierarchy.h | 120 -------- 8 files changed, 1596 deletions(-) delete mode 100644 src/hierarchies/conjugate_hierarchy.h delete mode 100644 src/hierarchies/lin_reg_uni_hierarchy.cc delete mode 100644 src/hierarchies/lin_reg_uni_hierarchy.h delete mode 100644 src/hierarchies/nnig_hierarchy.cc delete mode 100644 src/hierarchies/nnw_hierarchy.cc delete mode 100644 src/hierarchies/nnw_hierarchy.h delete mode 100644 src/hierarchies/nnxig_hierarchy.cc delete mode 100644 src/hierarchies/nnxig_hierarchy.h diff --git a/src/hierarchies/conjugate_hierarchy.h b/src/hierarchies/conjugate_hierarchy.h deleted file mode 100644 index 3a7350c98..000000000 --- a/src/hierarchies/conjugate_hierarchy.h +++ /dev/null @@ -1,206 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_ - -#include "base_hierarchy.h" - -//! Template base class for conjugate hierarchy objects. - -//! This class acts as the base class for conjugate models, i.e. ones for which -//! both the prior and posterior distribution have the same form -//! (non-conjugate hierarchies should instead inherit directly from -//! `BaseHierarchy`). This also means that the marginal distribution for the -//! data is available in closed form. For this reason, each class deriving from -//! this one must have a free method with one of the following signatures, -//! based on whether it depends on covariates or not: -//! double marg_lpdf( -//! const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, -//! const Eigen::RowVectorXd &covariate) const; -//! or -//! double marg_lpdf( -//! const Hyperparams ¶ms, const Eigen::RowVectorXd &datum) const; -//! This returns the evaluation of the marginal distribution on the given data -//! point (and covariate, if any), conditioned on the provided `Hyperparams` -//! object. The latter may contain either prior or posterior values for -//! hyperparameters, depending on where this function is called within the -//! library. -//! For more information, please refer to parent classes `AbstractHierarchy` -//! and `BaseHierarchy`. - -template -class ConjugateHierarchy - : public BaseHierarchy { - public: - using BaseHierarchy::hypers; - using BaseHierarchy::posterior_hypers; - using BaseHierarchy::state; - - ConjugateHierarchy() = default; - ~ConjugateHierarchy() = default; - - //! Public wrapper for `marg_lpdf()` methods - virtual double get_marg_lpdf( - const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const; - - //! Evaluates the log-prior predictive distribution of data in a single point - //! @param datum Point which is to be evaluated - //! @param covariate (Optional) covariate vector associated to datum - //! @return The evaluation of the lpdf - double prior_pred_lpdf(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = - Eigen::RowVectorXd(0)) const override { - return get_marg_lpdf(*hypers, datum, covariate); - } - - //! Evaluates the log-conditional predictive distr. of data in a single point - //! @param datum Point which is to be evaluated - //! @param covariate (Optional) covariate vector associated to datum - //! @return The evaluation of the lpdf - double conditional_pred_lpdf(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = - Eigen::RowVectorXd(0)) const override { - return get_marg_lpdf(posterior_hypers, datum, covariate); - } - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - virtual Eigen::VectorXd prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - virtual Eigen::VectorXd conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Generates new state values from the centering posterior distribution - //! @param update_params Save posterior hypers after the computation? - void sample_full_cond(bool update_params = true) override { - if (this->card == 0) { - // No posterior update possible - static_cast(this)->sample_prior(); - } else { - Hyperparams params = - update_params - ? static_cast(this)->compute_posterior_hypers() - : posterior_hypers; - state = static_cast(this)->draw(params); - } - } - - //! Saves posterior hyperparameters to the corresponding class member - void save_posterior_hypers() { - posterior_hypers = - static_cast(this)->compute_posterior_hypers(); - } - - //! Returns whether the hierarchy represents a conjugate model or not - bool is_conjugate() const override { return true; } - - protected: - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! @return The evaluation of the lpdf - virtual double marg_lpdf(const Hyperparams ¶ms, - const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const { - if (!this->is_dependent()) { - throw std::runtime_error( - "Cannot call this function from a non-dependent hierarchy"); - } else { - throw std::runtime_error("Not implemented"); - } - } - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - virtual double marg_lpdf(const Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const { - if (this->is_dependent()) { - throw std::runtime_error( - "Cannot call this function from a dependent hierarchy"); - } else { - throw std::runtime_error("Not implemented"); - } - } -}; - -template -double ConjugateHierarchy::get_marg_lpdf( - const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const { - if (this->is_dependent()) { - return marg_lpdf(params, datum, covariate); - } else { - return marg_lpdf(params, datum); - } -} - -template -Eigen::VectorXd -ConjugateHierarchy::prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { - Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->prior_pred_lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->prior_pred_lpdf( - data.row(i), covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->prior_pred_lpdf( - data.row(i), covariates.row(i)); - } - } - return lpdf; -} - -template -Eigen::VectorXd ConjugateHierarchy:: - conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { - Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), covariates.row(i)); - } - } - return lpdf; -} - -#endif diff --git a/src/hierarchies/lin_reg_uni_hierarchy.cc b/src/hierarchies/lin_reg_uni_hierarchy.cc deleted file mode 100644 index d48f6ada8..000000000 --- a/src/hierarchies/lin_reg_uni_hierarchy.cc +++ /dev/null @@ -1,166 +0,0 @@ -#include "lin_reg_uni_hierarchy.h" - -#include -#include -#include - -#include "src/utils/eigen_utils.h" -#include "src/utils/proto_utils.h" -#include "src/utils/rng.h" - -double LinRegUniHierarchy::like_lpdf( - const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const { - return stan::math::normal_lpdf( - datum(0), state.regression_coeffs.dot(covariate), sqrt(state.var)); -} - -double LinRegUniHierarchy::marg_lpdf( - const LinRegUni::Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const { - double sig_n = sqrt( - (1 + (covariate * params.var_scaling_inv * covariate.transpose())(0)) * - params.scale / params.shape); - return stan::math::student_t_lpdf(datum(0), 2 * params.shape, - covariate.dot(params.mean), sig_n); -} - -void LinRegUniHierarchy::initialize_state() { - state.regression_coeffs = hypers->mean; - state.var = hypers->scale / (hypers->shape + 1); -} - -void LinRegUniHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); - dim = hypers->mean.size(); - hypers->var_scaling = - bayesmix::to_eigen(prior->fixed_values().var_scaling()); - hypers->var_scaling_inv = stan::math::inverse_spd(hypers->var_scaling); - hypers->shape = prior->fixed_values().shape(); - hypers->scale = prior->fixed_values().scale(); - // Check validity - if (dim != hypers->var_scaling.rows()) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - bayesmix::check_spd(hypers->var_scaling); - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void LinRegUniHierarchy::update_hypers( - const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -LinRegUni::State LinRegUniHierarchy::draw( - const LinRegUni::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - LinRegUni::State out; - out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - out.regression_coeffs = stan::math::multi_normal_prec_rng( - params.mean, params.var_scaling / out.var, rng); - return out; -} - -void LinRegUniHierarchy::update_summary_statistics( - const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate, - bool add) { - if (add) { - data_sum_squares += datum(0) * datum(0); - covar_sum_squares += covariate.transpose() * covariate; - mixed_prod += datum(0) * covariate.transpose(); - } else { - data_sum_squares -= datum(0) * datum(0); - covar_sum_squares -= covariate.transpose() * covariate; - mixed_prod -= datum(0) * covariate.transpose(); - } -} - -void LinRegUniHierarchy::clear_summary_statistics() { - mixed_prod = Eigen::VectorXd::Zero(dim); - data_sum_squares = 0.0; - covar_sum_squares = Eigen::MatrixXd::Zero(dim, dim); -} - -LinRegUni::Hyperparams LinRegUniHierarchy::compute_posterior_hypers() const { - if (card == 0) { // no update possible - return *hypers; - } - // Compute posterior hyperparameters - LinRegUni::Hyperparams post_params; - post_params.var_scaling = covar_sum_squares + hypers->var_scaling; - auto llt = post_params.var_scaling.llt(); - post_params.var_scaling_inv = llt.solve(Eigen::MatrixXd::Identity(dim, dim)); - post_params.mean = - llt.solve(mixed_prod + hypers->var_scaling * hypers->mean); - post_params.shape = hypers->shape + 0.5 * card; - post_params.scale = - hypers->scale + - 0.5 * (data_sum_squares + - hypers->mean.transpose() * hypers->var_scaling * hypers->mean - - post_params.mean.transpose() * post_params.var_scaling * - post_params.mean); - return post_params; -} - -void LinRegUniHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.regression_coeffs = - bayesmix::to_eigen(statecast.lin_reg_uni_ls_state().regression_coeffs()); - state.var = statecast.lin_reg_uni_ls_state().var(); - set_card(statecast.cardinality()); -} - -std::shared_ptr -LinRegUniHierarchy::get_state_proto() const { - bayesmix::LinRegUniLSState state_; - bayesmix::to_proto(state.regression_coeffs, - state_.mutable_regression_coeffs()); - state_.set_var(state.var); - - auto out = std::make_shared(); - out->mutable_lin_reg_uni_ls_state()->CopyFrom(state_); - return out; -} - -void LinRegUniHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).lin_reg_uni_state(); - hypers->mean = bayesmix::to_eigen(hyperscast.mean()); - hypers->var_scaling = bayesmix::to_eigen(hyperscast.var_scaling()); - hypers->scale = hyperscast.scale(); - hypers->shape = hyperscast.shape(); -} - -std::shared_ptr -LinRegUniHierarchy::get_hypers_proto() const { - bayesmix::MultiNormalIGDistribution hypers_; - bayesmix::to_proto(hypers->mean, hypers_.mutable_mean()); - bayesmix::to_proto(hypers->var_scaling, hypers_.mutable_var_scaling()); - hypers_.set_shape(hypers->shape); - hypers_.set_scale(hypers->scale); - - auto out = std::make_shared(); - out->mutable_lin_reg_uni_state()->CopyFrom(hypers_); - return out; -} diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h deleted file mode 100644 index 7789a6065..000000000 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ /dev/null @@ -1,144 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Linear regression hierarchy for univariate data. - -//! This class implements a dependent hierarchy which represents the classical -//! univariate Bayesian linear regression model, i.e.: -//! y_i | \beta, x_i, \sigma^2 \sim N(\beta^T x_i, sigma^2) -//! \beta | \sigma^2 \sim N(\mu, sigma^2 Lambda^{-1}) -//! \sigma^2 \sim InvGamma(a, b) -//! -//! The state consists of the `regression_coeffs` \beta, and the `var` sigma^2. -//! Lambda is called the variance-scaling factor. For more information, please -//! refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and -//! `ConjugateHierarchy`. - -namespace LinRegUni { -//! Custom container for State values -struct State { - Eigen::VectorXd regression_coeffs; - double var; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - Eigen::VectorXd mean; - Eigen::MatrixXd var_scaling; - Eigen::MatrixXd var_scaling_inv; - double shape; - double scale; -}; -} // namespace LinRegUni - -class LinRegUniHierarchy - : public ConjugateHierarchy { - public: - LinRegUniHierarchy() = default; - ~LinRegUniHierarchy() = default; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - LinRegUni::State draw(const LinRegUni::Hyperparams ¶ms); - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param covariate Covariate vector associated to datum - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate, - bool add) override; - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Returns the Protobuf ID associated to this class - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::LinRegUni; - } - - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Returns the dimension of the coefficients vector - unsigned int get_dim() const { return dim; } - - //! Computes and return posterior hypers given data currently in this cluster - LinRegUni::Hyperparams compute_posterior_hypers() const; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return false; } - - //! Returns whether the hierarchy depends on covariate values or not - bool is_dependent() const override { return true; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const override; - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! @return The evaluation of the lpdf - double marg_lpdf(const LinRegUni::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const override; - - //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Dimension of the coefficients vector - unsigned int dim; - - //! Represents pieces of y^t y - double data_sum_squares; - - //! Represents pieces of X^T X - Eigen::MatrixXd covar_sum_squares; - - //! Represents pieces of X^t y - Eigen::VectorXd mixed_prod; -}; - -#endif // BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ diff --git a/src/hierarchies/nnig_hierarchy.cc b/src/hierarchies/nnig_hierarchy.cc deleted file mode 100644 index c2b055178..000000000 --- a/src/hierarchies/nnig_hierarchy.cc +++ /dev/null @@ -1,267 +0,0 @@ -#include "nnig_hierarchy.h" - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "src/utils/rng.h" - -double NNIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { - return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); -} - -double NNIGHierarchy::marg_lpdf(const NNIG::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const { - double sig_n = sqrt(params.scale * (params.var_scaling + 1) / - (params.shape * params.var_scaling)); - return stan::math::student_t_lpdf(datum(0), 2 * params.shape, params.mean, - sig_n); -} - -void NNIGHierarchy::initialize_state() { - state.mean = hypers->mean; - state.var = hypers->scale / (hypers->shape + 1); -} - -void NNIGHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = prior->fixed_values().mean(); - hypers->var_scaling = prior->fixed_values().var_scaling(); - hypers->shape = prior->fixed_values().shape(); - hypers->scale = prior->fixed_values().scale(); - // Check validity - if (hypers->var_scaling <= 0) { - throw std::invalid_argument("Variance-scaling parameter must be > 0"); - } - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - } - - else if (prior->has_normal_mean_prior()) { - // Set initial values - hypers->mean = prior->normal_mean_prior().mean_prior().mean(); - hypers->var_scaling = prior->normal_mean_prior().var_scaling(); - hypers->shape = prior->normal_mean_prior().shape(); - hypers->scale = prior->normal_mean_prior().scale(); - // Check validity - if (hypers->var_scaling <= 0) { - throw std::invalid_argument("Variance-scaling parameter must be > 0"); - } - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - } - - else if (prior->has_ngg_prior()) { - // Get hyperparameters: - // for mu0 - double mu00 = prior->ngg_prior().mean_prior().mean(); - double sigma00 = prior->ngg_prior().mean_prior().var(); - // for lambda0 - double alpha00 = prior->ngg_prior().var_scaling_prior().shape(); - double beta00 = prior->ngg_prior().var_scaling_prior().rate(); - // for beta0 - double a00 = prior->ngg_prior().scale_prior().shape(); - double b00 = prior->ngg_prior().scale_prior().rate(); - // for alpha0 - double alpha0 = prior->ngg_prior().shape(); - // Check validity - if (sigma00 <= 0) { - throw std::invalid_argument("Variance parameter must be > 0"); - } - if (alpha00 <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (beta00 <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - if (a00 <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (b00 <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - if (alpha0 <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - // Set initial values - hypers->mean = mu00; - hypers->var_scaling = alpha00 / beta00; - hypers->shape = alpha0; - hypers->scale = a00 / b00; - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void NNIGHierarchy::update_hypers( - const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - - if (prior->has_fixed_values()) { - return; - } - - else if (prior->has_normal_mean_prior()) { - // Get hyperparameters - double mu00 = prior->normal_mean_prior().mean_prior().mean(); - double sig200 = prior->normal_mean_prior().mean_prior().var(); - double lambda0 = prior->normal_mean_prior().var_scaling(); - // Compute posterior hyperparameters - double prec = 0.0; - double num = 0.0; - for (auto &st : states) { - double mean = st.uni_ls_state().mean(); - double var = st.uni_ls_state().var(); - prec += 1 / var; - num += mean / var; - } - prec = 1 / sig200 + lambda0 * prec; - num = mu00 / sig200 + lambda0 * num; - double mu_n = num / prec; - double sig2_n = 1 / prec; - // Update hyperparameters with posterior random sampling - hypers->mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); - } - - else if (prior->has_ngg_prior()) { - // Get hyperparameters: - // for mu0 - double mu00 = prior->ngg_prior().mean_prior().mean(); - double sig200 = prior->ngg_prior().mean_prior().var(); - // for lambda0 - double alpha00 = prior->ngg_prior().var_scaling_prior().shape(); - double beta00 = prior->ngg_prior().var_scaling_prior().rate(); - // for tau0 - double a00 = prior->ngg_prior().scale_prior().shape(); - double b00 = prior->ngg_prior().scale_prior().rate(); - // Compute posterior hyperparameters - double b_n = 0.0; - double num = 0.0; - double beta_n = 0.0; - for (auto &st : states) { - double mean = st.uni_ls_state().mean(); - double var = st.uni_ls_state().var(); - b_n += 1 / var; - num += mean / var; - beta_n += (hypers->mean - mean) * (hypers->mean - mean) / var; - } - double var = hypers->var_scaling * b_n + 1 / sig200; - b_n += b00; - num = hypers->var_scaling * num + mu00 / sig200; - beta_n = beta00 + 0.5 * beta_n; - double sig_n = 1 / var; - double mu_n = num / var; - double alpha_n = alpha00 + 0.5 * states.size(); - double a_n = a00 + states.size() * hypers->shape; - // Update hyperparameters with posterior random Gibbs sampling - hypers->mean = stan::math::normal_rng(mu_n, sig_n, rng); - hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); - hypers->scale = stan::math::gamma_rng(a_n, b_n, rng); - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -NNIG::State NNIGHierarchy::draw(const NNIG::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - NNIG::State out; - out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - out.mean = stan::math::normal_rng(params.mean, - sqrt(state.var / params.var_scaling), rng); - return out; -} - -void NNIGHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) { - if (add) { - data_sum += datum(0); - data_sum_squares += datum(0) * datum(0); - } else { - data_sum -= datum(0); - data_sum_squares -= datum(0) * datum(0); - } -} - -void NNIGHierarchy::clear_summary_statistics() { - data_sum = 0; - data_sum_squares = 0; -} - -NNIG::Hyperparams NNIGHierarchy::compute_posterior_hypers() const { - // Initialize relevant variables - if (card == 0) { // no update possible - return *hypers; - } - // Compute posterior hyperparameters - NNIG::Hyperparams post_params; - double y_bar = data_sum / (1.0 * card); // sample mean - double ss = data_sum_squares - card * y_bar * y_bar; - post_params.mean = (hypers->var_scaling * hypers->mean + data_sum) / - (hypers->var_scaling + card); - post_params.var_scaling = hypers->var_scaling + card; - post_params.shape = hypers->shape + 0.5 * card; - post_params.scale = hypers->scale + 0.5 * ss + - 0.5 * hypers->var_scaling * card * - (y_bar - hypers->mean) * (y_bar - hypers->mean) / - (card + hypers->var_scaling); - return post_params; -} - -void NNIGHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.mean = statecast.uni_ls_state().mean(); - state.var = statecast.uni_ls_state().var(); - set_card(statecast.cardinality()); -} - -std::shared_ptr -NNIGHierarchy::get_state_proto() const { - bayesmix::UniLSState state_; - state_.set_mean(state.mean); - state_.set_var(state.var); - - auto out = std::make_shared(); - out->mutable_uni_ls_state()->CopyFrom(state_); - return out; -} - -void NNIGHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).nnig_state(); - hypers->mean = hyperscast.mean(); - hypers->var_scaling = hyperscast.var_scaling(); - hypers->scale = hyperscast.scale(); - hypers->shape = hyperscast.shape(); -} - -std::shared_ptr -NNIGHierarchy::get_hypers_proto() const { - bayesmix::NIGDistribution hypers_; - hypers_.set_mean(hypers->mean); - hypers_.set_var_scaling(hypers->var_scaling); - hypers_.set_shape(hypers->shape); - hypers_.set_scale(hypers->scale); - - auto out = std::make_shared(); - out->mutable_nnig_state()->CopyFrom(hypers_); - return out; -} diff --git a/src/hierarchies/nnw_hierarchy.cc b/src/hierarchies/nnw_hierarchy.cc deleted file mode 100644 index e1e58275f..000000000 --- a/src/hierarchies/nnw_hierarchy.cc +++ /dev/null @@ -1,373 +0,0 @@ -#include "nnw_hierarchy.h" - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "matrix.pb.h" -#include "src/utils/distributions.h" -#include "src/utils/eigen_utils.h" -#include "src/utils/proto_utils.h" -#include "src/utils/rng.h" - -double NNWHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { - return bayesmix::multi_normal_prec_lpdf(datum, state.mean, state.prec_chol, - state.prec_logdet); -} - -double NNWHierarchy::marg_lpdf(const NNW::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const { - NNW::Hyperparams pred_params = get_predictive_t_parameters(params); - Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); - double logdet = 2 * log(diag.array()).sum(); - - return bayesmix::multi_student_t_invscale_lpdf( - datum, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, - logdet); -} - -Eigen::VectorXd NNWHierarchy::like_lpdf_grid( - const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { - // Custom, optimized grid method - return bayesmix::multi_normal_prec_lpdf_grid( - data, state.mean, state.prec_chol, state.prec_logdet); -} - -Eigen::VectorXd NNWHierarchy::prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { - // Custom, optimized grid method - NNW::Hyperparams pred_params = get_predictive_t_parameters(*hypers); - Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); - double logdet = 2 * log(diag.array()).sum(); - - return bayesmix::multi_student_t_invscale_lpdf_grid( - data, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, - logdet); -} - -Eigen::VectorXd NNWHierarchy::conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { - // Custom, optimized grid method - NNW::Hyperparams pred_params = - get_predictive_t_parameters(compute_posterior_hypers()); - Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); - double logdet = 2 * log(diag.array()).sum(); - - return bayesmix::multi_student_t_invscale_lpdf_grid( - data, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, - logdet); -} - -void NNWHierarchy::initialize_state() { - state.mean = hypers->mean; - write_prec_to_state( - hypers->var_scaling * Eigen::MatrixXd::Identity(dim, dim), &state); -} - -void NNWHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); - dim = hypers->mean.size(); - hypers->var_scaling = prior->fixed_values().var_scaling(); - hypers->scale = bayesmix::to_eigen(prior->fixed_values().scale()); - hypers->deg_free = prior->fixed_values().deg_free(); - // Check validity - if (hypers->var_scaling <= 0) { - throw std::invalid_argument("Variance-scaling parameter must be > 0"); - } - if (dim != hypers->scale.rows()) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - if (hypers->deg_free <= dim - 1) { - throw std::invalid_argument("Degrees of freedom parameter is not valid"); - } - } - - else if (prior->has_normal_mean_prior()) { - // Get hyperparameters - Eigen::VectorXd mu00 = - bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().mean()); - dim = mu00.size(); - Eigen::MatrixXd sigma00 = - bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().var()); - double lambda0 = prior->normal_mean_prior().var_scaling(); - Eigen::MatrixXd tau0 = - bayesmix::to_eigen(prior->normal_mean_prior().scale()); - double nu0 = prior->normal_mean_prior().deg_free(); - // Check validity - unsigned int dim = mu00.size(); - if (sigma00.rows() != dim or tau0.rows() != dim) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - bayesmix::check_spd(sigma00); - if (lambda0 <= 0) { - throw std::invalid_argument("Variance-scaling parameter must be > 0"); - } - bayesmix::check_spd(tau0); - if (nu0 <= dim - 1) { - throw std::invalid_argument("Degrees of freedom parameter is not valid"); - } - // Set initial values - hypers->mean = mu00; - hypers->var_scaling = lambda0; - hypers->scale = tau0; - hypers->deg_free = nu0; - } - - else if (prior->has_ngiw_prior()) { - // Get hyperparameters: - // for mu0 - Eigen::VectorXd mu00 = - bayesmix::to_eigen(prior->ngiw_prior().mean_prior().mean()); - dim = mu00.size(); - Eigen::MatrixXd sigma00 = - bayesmix::to_eigen(prior->ngiw_prior().mean_prior().var()); - // for lambda0 - double alpha00 = prior->ngiw_prior().var_scaling_prior().shape(); - double beta00 = prior->ngiw_prior().var_scaling_prior().rate(); - // for tau0 - double nu00 = prior->ngiw_prior().scale_prior().deg_free(); - Eigen::MatrixXd tau00 = - bayesmix::to_eigen(prior->ngiw_prior().scale_prior().scale()); - // for nu0 - double nu0 = prior->ngiw_prior().deg_free(); - // Check validity: - // dimensionality - if (sigma00.rows() != dim or tau00.rows() != dim) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - // for mu0 - bayesmix::check_spd(sigma00); - // for lambda0 - if (alpha00 <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (beta00 <= 0) { - throw std::invalid_argument("Rate parameter must be > 0"); - } - // for tau0 - if (nu00 <= 0) { - throw std::invalid_argument("Degrees of freedom parameter must be > 0"); - } - bayesmix::check_spd(tau00); - // check nu0 - if (nu0 <= dim - 1) { - throw std::invalid_argument("Degrees of freedom parameter is not valid"); - } - // Set initial values - hypers->mean = mu00; - hypers->var_scaling = alpha00 / beta00; - hypers->scale = tau00 / (nu00 + dim + 1); - hypers->deg_free = nu0; - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } - hypers->scale_inv = stan::math::inverse_spd(hypers->scale); - hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); -} - -void NNWHierarchy::update_hypers( - const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } - - else if (prior->has_normal_mean_prior()) { - // Get hyperparameters - Eigen::VectorXd mu00 = - bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().mean()); - Eigen::MatrixXd sigma00 = - bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().var()); - double lambda0 = prior->normal_mean_prior().var_scaling(); - // Compute posterior hyperparameters - Eigen::MatrixXd sigma00inv = stan::math::inverse_spd(sigma00); - Eigen::MatrixXd prec = Eigen::MatrixXd::Zero(dim, dim); - Eigen::VectorXd num = Eigen::MatrixXd::Zero(dim, 1); - for (auto &st : states) { - Eigen::MatrixXd prec_i = bayesmix::to_eigen(st.multi_ls_state().prec()); - prec += prec_i; - num += prec_i * bayesmix::to_eigen(st.multi_ls_state().mean()); - } - prec = hypers->var_scaling * prec + sigma00inv; - num = hypers->var_scaling * num + sigma00inv * mu00; - Eigen::VectorXd mu_n = prec.llt().solve(num); - // Update hyperparameters with posterior sampling - hypers->mean = stan::math::multi_normal_prec_rng(mu_n, prec, rng); - } - - else if (prior->has_ngiw_prior()) { - // Get hyperparameters: - // for mu0 - Eigen::VectorXd mu00 = - bayesmix::to_eigen(prior->ngiw_prior().mean_prior().mean()); - Eigen::MatrixXd sigma00 = - bayesmix::to_eigen(prior->ngiw_prior().mean_prior().var()); - // for lambda0 - double alpha00 = prior->ngiw_prior().var_scaling_prior().shape(); - double beta00 = prior->ngiw_prior().var_scaling_prior().rate(); - // for tau0 - double nu00 = prior->ngiw_prior().scale_prior().deg_free(); - Eigen::MatrixXd tau00 = - bayesmix::to_eigen(prior->ngiw_prior().scale_prior().scale()); - // Compute posterior hyperparameters - Eigen::MatrixXd sigma00inv = stan::math::inverse_spd(sigma00); - Eigen::MatrixXd tau_n = Eigen::MatrixXd::Zero(dim, dim); - Eigen::VectorXd num = Eigen::MatrixXd::Zero(dim, 1); - double beta_n = 0.0; - for (auto &st : states) { - Eigen::VectorXd mean = bayesmix::to_eigen(st.multi_ls_state().mean()); - Eigen::MatrixXd prec = bayesmix::to_eigen(st.multi_ls_state().prec()); - tau_n += prec; - num += prec * mean; - beta_n += - (hypers->mean - mean).transpose() * prec * (hypers->mean - mean); - } - Eigen::MatrixXd prec_n = hypers->var_scaling * tau_n + sigma00inv; - tau_n += tau00; - num = hypers->var_scaling * num + sigma00inv * mu00; - beta_n = beta00 + 0.5 * beta_n; - Eigen::MatrixXd sig_n = stan::math::inverse_spd(prec_n); - Eigen::VectorXd mu_n = sig_n * num; - double alpha_n = alpha00 + 0.5 * states.size(); - double nu_n = nu00 + states.size() * hypers->deg_free; - // Update hyperparameters with posterior random Gibbs sampling - hypers->mean = stan::math::multi_normal_rng(mu_n, sig_n, rng); - hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); - hypers->scale = stan::math::inv_wishart_rng(nu_n, tau_n, rng); - hypers->scale_inv = stan::math::inverse_spd(hypers->scale); - hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -NNW::State NNWHierarchy::draw(const NNW::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - Eigen::MatrixXd tau_new = - stan::math::wishart_rng(params.deg_free, params.scale, rng); - // Update state - NNW::State out; - out.mean = stan::math::multi_normal_prec_rng( - params.mean, tau_new * params.var_scaling, rng); - write_prec_to_state(tau_new, &out); - return out; -} - -void NNWHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) { - if (add) { - data_sum += datum.transpose(); - data_sum_squares += datum.transpose() * datum; - } else { - data_sum -= datum.transpose(); - data_sum_squares -= datum.transpose() * datum; - } -} - -void NNWHierarchy::clear_summary_statistics() { - data_sum = Eigen::VectorXd::Zero(dim); - data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); -} - -NNW::Hyperparams NNWHierarchy::compute_posterior_hypers() const { - if (card == 0) { // no update possible - return *hypers; - } - // Compute posterior hyperparameters - NNW::Hyperparams post_params; - post_params.var_scaling = hypers->var_scaling + card; - post_params.deg_free = hypers->deg_free + card; - Eigen::VectorXd mubar = data_sum.array() / card; // sample mean - post_params.mean = (hypers->var_scaling * hypers->mean + card * mubar) / - (hypers->var_scaling + card); - // Compute tau_n - Eigen::MatrixXd tau_temp = - data_sum_squares - card * mubar * mubar.transpose(); - tau_temp += (card * hypers->var_scaling / (card + hypers->var_scaling)) * - (mubar - hypers->mean) * (mubar - hypers->mean).transpose(); - post_params.scale_inv = tau_temp + hypers->scale_inv; - post_params.scale = stan::math::inverse_spd(post_params.scale_inv); - post_params.scale_chol = - Eigen::LLT(post_params.scale).matrixU(); - return post_params; -} - -void NNWHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.mean = to_eigen(statecast.multi_ls_state().mean()); - state.prec = to_eigen(statecast.multi_ls_state().prec()); - state.prec_chol = to_eigen(statecast.multi_ls_state().prec_chol()); - Eigen::VectorXd diag = state.prec_chol.diagonal(); - state.prec_logdet = 2 * log(diag.array()).sum(); - set_card(statecast.cardinality()); -} - -std::shared_ptr -NNWHierarchy::get_state_proto() const { - bayesmix::MultiLSState state_; - bayesmix::to_proto(state.mean, state_.mutable_mean()); - bayesmix::to_proto(state.prec, state_.mutable_prec()); - bayesmix::to_proto(state.prec_chol, state_.mutable_prec_chol()); - - auto out = std::make_shared(); - out->mutable_multi_ls_state()->CopyFrom(state_); - return out; -} - -void NNWHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).nnw_state(); - hypers->mean = to_eigen(hyperscast.mean()); - hypers->var_scaling = hyperscast.var_scaling(); - hypers->deg_free = hyperscast.deg_free(); - hypers->scale = to_eigen(hyperscast.scale()); -} - -std::shared_ptr -NNWHierarchy::get_hypers_proto() const { - bayesmix::NWDistribution hypers_; - bayesmix::to_proto(hypers->mean, hypers_.mutable_mean()); - hypers_.set_var_scaling(hypers->var_scaling); - hypers_.set_deg_free(hypers->deg_free); - bayesmix::to_proto(hypers->scale, hypers_.mutable_scale()); - - auto out = std::make_shared(); - out->mutable_nnw_state()->CopyFrom(hypers_); - return out; -} - -void NNWHierarchy::write_prec_to_state(const Eigen::MatrixXd &prec_, - NNW::State *out) { - out->prec = prec_; - // Update prec utilities - out->prec_chol = Eigen::LLT(prec_).matrixU(); - Eigen::VectorXd diag = out->prec_chol.diagonal(); - out->prec_logdet = 2 * log(diag.array()).sum(); -} - -NNW::Hyperparams NNWHierarchy::get_predictive_t_parameters( - const NNW::Hyperparams ¶ms) const { - // Compute dof and scale of marginal distribution - double nu_n = params.deg_free - dim + 1; - double coeff = (params.var_scaling + 1) / (params.var_scaling * nu_n); - Eigen::MatrixXd scale_chol_n = params.scale_chol / std::sqrt(coeff); - - NNW::Hyperparams out; - out.mean = params.mean; - out.deg_free = nu_n; - out.scale_chol = scale_chol_n; - return out; -} diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h deleted file mode 100644 index 1b149d422..000000000 --- a/src/hierarchies/nnw_hierarchy.h +++ /dev/null @@ -1,168 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Normal Normal-Wishart hierarchy for multivariate data. - -//! This class represents a hierarchy, i.e. a cluster, whose multivariate data -//! are distributed according to a multinomial normal likelihood, the -//! parameters of which have a Normal-Wishart centering distribution. That is: -//! f(x_i|mu,tau) = N(mu,tau^{-1}) -//! (mu,tau) ~ NW(mu0, lambda0, tau0, nu0) -//! The state is composed of mean and precision matrix. The Cholesky factor and -//! log-determinant of the latter are also included in the container for -//! efficiency reasons. The state's hyperparameters, contained in the Hypers -//! object, are (mu0, lambda0, tau0, nu0), which are respectively vector, -//! scalar, matrix, and scalar. Note that this hierarchy is conjugate, thus the -//! marginal distribution is available in closed form. For more information, -//! please refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and -//! `ConjugateHierarchy`. - -namespace NNW { -//! Custom container for State values -struct State { - Eigen::VectorXd mean; - Eigen::MatrixXd prec; - Eigen::MatrixXd prec_chol; - double prec_logdet; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - Eigen::VectorXd mean; - double var_scaling; - double deg_free; - Eigen::MatrixXd scale; - Eigen::MatrixXd scale_inv; - Eigen::MatrixXd scale_chol; -}; -} // namespace NNW - -class NNWHierarchy - : public ConjugateHierarchy { - public: - NNWHierarchy() = default; - ~NNWHierarchy() = default; - - // EVALUATION FUNCTIONS FOR GRIDS OF POINTS - //! Evaluates the log-likelihood of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - Eigen::VectorXd like_lpdf_grid(const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = - Eigen::MatrixXd(0, 0)) const override; - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - Eigen::VectorXd prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - Eigen::VectorXd conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - NNW::State draw(const NNW::Hyperparams ¶ms); - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Returns the Protobuf ID associated to this class - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::NNW; - } - - //! Computes and return posterior hypers given data currently in this cluster - NNW::Hyperparams compute_posterior_hypers() const; - - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return true; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum) const override; - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double marg_lpdf(const NNW::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) override; - - //! Writes prec and its utilities to the given state object by pointer - void write_prec_to_state(const Eigen::MatrixXd &prec_, NNW::State *out); - - //! Returns parameters for the predictive Student's t distribution - NNW::Hyperparams get_predictive_t_parameters( - const NNW::Hyperparams ¶ms) const; - - //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Dimension of data space - unsigned int dim; - - //! Sum of data points currently belonging to the cluster - Eigen::VectorXd data_sum; - - //! Sum of squared data points currently belonging to the cluster - Eigen::MatrixXd data_sum_squares; -}; - -#endif // BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ diff --git a/src/hierarchies/nnxig_hierarchy.cc b/src/hierarchies/nnxig_hierarchy.cc deleted file mode 100644 index f7bc62a16..000000000 --- a/src/hierarchies/nnxig_hierarchy.cc +++ /dev/null @@ -1,152 +0,0 @@ -#include "nnxig_hierarchy.h" - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "src/utils/rng.h" - -double NNxIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { - return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); -} - -void NNxIGHierarchy::initialize_state() { - state.mean = hypers->mean; - state.var = hypers->scale / (hypers->shape + 1); -} - -void NNxIGHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = prior->fixed_values().mean(); - hypers->var = prior->fixed_values().var(); - hypers->shape = prior->fixed_values().shape(); - hypers->scale = prior->fixed_values().scale(); - - // Check validity - if (hypers->var <= 0) { - throw std::invalid_argument("Variance parameter must be > 0"); - } - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void NNxIGHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) { - if (add) { - data_sum += datum(0); - data_sum_squares += datum(0) * datum(0); - } else { - data_sum -= datum(0); - data_sum_squares -= datum(0) * datum(0); - } -} - -void NNxIGHierarchy::update_hypers( - const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void NNxIGHierarchy::clear_summary_statistics() { - data_sum = 0; - data_sum_squares = 0; -} - -void NNxIGHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.mean = statecast.uni_ls_state().mean(); - state.var = statecast.uni_ls_state().var(); - set_card(statecast.cardinality()); -} - -void NNxIGHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).nnxig_state(); - hypers->mean = hyperscast.mean(); - hypers->var = hyperscast.var(); - hypers->scale = hyperscast.scale(); - hypers->shape = hyperscast.shape(); -} - -std::shared_ptr -NNxIGHierarchy::get_state_proto() const { - bayesmix::UniLSState state_; - state_.set_mean(state.mean); - state_.set_var(state.var); - - auto out = std::make_shared(); - out->mutable_uni_ls_state()->CopyFrom(state_); - return out; -} - -std::shared_ptr -NNxIGHierarchy::get_hypers_proto() const { - bayesmix::NxIGDistribution hypers_; - hypers_.set_mean(hypers->mean); - hypers_.set_var(hypers->var); - hypers_.set_shape(hypers->shape); - hypers_.set_scale(hypers->scale); - - auto out = std::make_shared(); - out->mutable_nnxig_state()->CopyFrom(hypers_); - return out; -} - -void NNxIGHierarchy::sample_full_cond(bool update_params) { - if (this->card == 0) { - // No posterior update possible - sample_prior(); - } else { - NNxIG::Hyperparams params = - update_params ? compute_posterior_hypers() : posterior_hypers; - state = draw(params); - } -} - -NNxIG::State NNxIGHierarchy::draw(const NNxIG::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - NNxIG::State out; - out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - out.mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng); - return out; -} - -NNxIG::Hyperparams NNxIGHierarchy::compute_posterior_hypers() const { - // Initialize relevant variables - if (card == 0) { // no update possible - return *hypers; - } - // Compute posterior hyperparameters - NNxIG::Hyperparams post_params; - double var_y = data_sum_squares - 2 * state.mean * data_sum + - card * state.mean * state.mean; - post_params.mean = (hypers->var * data_sum + state.var * hypers->mean) / - (card * hypers->var + state.var); - post_params.var = - (state.var * hypers->var) / (card * hypers->var + state.var); - post_params.shape = hypers->shape + 0.5 * card; - post_params.scale = hypers->scale + 0.5 * var_y; - return post_params; -} - -void NNxIGHierarchy::save_posterior_hypers() { - posterior_hypers = compute_posterior_hypers(); -} diff --git a/src/hierarchies/nnxig_hierarchy.h b/src/hierarchies/nnxig_hierarchy.h deleted file mode 100644 index de5e878f6..000000000 --- a/src/hierarchies/nnxig_hierarchy.h +++ /dev/null @@ -1,120 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "base_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Non Conjugate Normal Normal-InverseGamma hierarchy for univariate data. - -//! This class represents a hierarchical model where data are distributed -//! according to a normal likelihood, the parameters of which have a -//! Normal-InverseGamma centering distribution. That is: -//! f(x_i|mu,sig) = N(mu,sig^2) -//! mu ~ N(mu0, sigma0) -//! sig^2 ~ IG(alpha0, beta0) -//! The state is composed of mean and variance. The state hyperparameters, -//! contained in the Hypers object, are (mu0, sigma0, alpha0, beta0), all -//! scalar values. Note that this hierarchy is non conjugate. - -namespace NNxIG { -//! Custom container for State values -struct State { - double mean, var; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - double mean, var, shape, scale; -}; - -}; // namespace NNxIG - -class NNxIGHierarchy - : public BaseHierarchy { - public: - NNxIGHierarchy() = default; - ~NNxIGHierarchy() = default; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - NNxIG::State draw(const NNxIG::Hyperparams ¶ms); - - //! Generates new state values from the centering posterior distribution - //! @param update_params Save posterior hypers after the computation? - void sample_full_cond(bool update_params = true) override; - - //! Saves posterior hyperparameters to the corresponding class member - void save_posterior_hypers(); - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Returns the Protobuf ID associated to this class - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::NNxIG; - } - - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Computes and return posterior hypers given data currently in this cluster - NNxIG::Hyperparams compute_posterior_hypers() const; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return false; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) override; - - //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Sum of data points currently belonging to the cluster - double data_sum = 0; - - //! Sum of squared data points currently belonging to the cluster - double data_sum_squares = 0; -}; - -#endif // BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ From 713d52d4a61e25a4d3c409abaf82b3a7f79ee840 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 20 Jan 2022 22:28:08 +0100 Subject: [PATCH 051/317] Enabled all test suites commenting non-implemented hierarchies --- test/CMakeLists.txt | 24 ++++----- test/lpdf.cc | 115 ++++++++++++++++++++++---------------------- test/write_proto.cc | 38 +++++++-------- 3 files changed, 88 insertions(+), 89 deletions(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5a238990f..ea406f1b3 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -16,20 +16,20 @@ set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) add_executable(test_bayesmix $ - # write_proto.cc - # proto_utils.cc - hierarchies.cc - # lpdf.cc - # priors.cc - # eigen_utils.cc - # distributions.cc - # semi_hdp.cc - # collectors.cc - # runtime.cc - # rng.cc - # logit_sb.cc + write_proto.cc + proto_utils.cc likelihoods.cc prior_models.cc + hierarchies.cc + lpdf.cc + # priors.cc // OLD, USEREI prior_models.cc + eigen_utils.cc + distributions.cc + semi_hdp.cc + collectors.cc + runtime.cc + rng.cc + logit_sb.cc ) target_include_directories(test_bayesmix PUBLIC ${INCLUDE_PATHS}) diff --git a/test/lpdf.cc b/test/lpdf.cc index 5a97220a0..f868aba05 100644 --- a/test/lpdf.cc +++ b/test/lpdf.cc @@ -5,9 +5,9 @@ #include #include "algorithm_state.pb.h" -#include "src/hierarchies/lin_reg_uni_hierarchy.h" +// #include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" -#include "src/hierarchies/nnw_hierarchy.h" +// #include "src/hierarchies/nnw_hierarchy.h" #include "src/utils/proto_utils.h" TEST(lpdf, nnig) { @@ -137,63 +137,62 @@ TEST(lpdf, nnig) { // ASSERT_DOUBLE_EQ(marg, marg_murphy); // } -TEST(lpdf, lin_reg_uni) { - // Create hierarchy objects - LinRegUniHierarchy hier; - bayesmix::LinRegUniPrior prior; - int dim = 3; - - // Generate data - Eigen::VectorXd datum(1); - datum << 1.5; - Eigen::VectorXd cov = Eigen::VectorXd::Random(dim); - - // Create parameters, both Eigen and proto - Eigen::VectorXd mu0(dim); - for (int i = 0; i < dim; i++) { - mu0(i) = 2 * i; - } - bayesmix::Vector mu0_proto; - bayesmix::to_proto(mu0, &mu0_proto); - auto Lambda0 = Eigen::MatrixXd::Identity(dim, dim); - bayesmix::Matrix Lambda0_proto; - bayesmix::to_proto(Lambda0, &Lambda0_proto); - double alpha0 = 2.0; - double beta0 = 2.0; - // Set parameters - *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; - *prior.mutable_fixed_values()->mutable_var_scaling() = Lambda0_proto; - prior.mutable_fixed_values()->set_shape(alpha0); - prior.mutable_fixed_values()->set_scale(beta0); - // Initialize hierarchy - hier.get_mutable_prior()->CopyFrom(prior); - hier.initialize(); +// TEST(lpdf, lin_reg_uni) { +// // Create hierarchy objects +// LinRegUniHierarchy hier; +// bayesmix::LinRegUniPrior prior; +// int dim = 3; + +// // Generate data +// Eigen::VectorXd datum(1); +// datum << 1.5; +// Eigen::VectorXd cov = Eigen::VectorXd::Random(dim); + +// // Create parameters, both Eigen and proto +// Eigen::VectorXd mu0(dim); +// for (int i = 0; i < dim; i++) { +// mu0(i) = 2 * i; +// } +// bayesmix::Vector mu0_proto; +// bayesmix::to_proto(mu0, &mu0_proto); +// auto Lambda0 = Eigen::MatrixXd::Identity(dim, dim); +// bayesmix::Matrix Lambda0_proto; +// bayesmix::to_proto(Lambda0, &Lambda0_proto); +// double alpha0 = 2.0; +// double beta0 = 2.0; +// // Set parameters +// *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; +// *prior.mutable_fixed_values()->mutable_var_scaling() = Lambda0_proto; +// prior.mutable_fixed_values()->set_shape(alpha0); +// prior.mutable_fixed_values()->set_scale(beta0); +// // Initialize hierarchy +// hier.get_mutable_prior()->CopyFrom(prior); +// hier.initialize(); - // Compute prior parameters - Eigen::VectorXd mean = mu0; - double var = beta0 / (alpha0 + 1); +// // Compute prior parameters +// Eigen::VectorXd mean = mu0; +// double var = beta0 / (alpha0 + 1); - // Compute posterior parameters - Eigen::MatrixXd Lambda_n = Lambda0 + cov * cov.transpose(); - Eigen::VectorXd mu_n = - stan::math::inverse_spd(Lambda_n) * (datum(0) * cov + Lambda0 * mu0); - double alpha_n = alpha0 + 0.5; - double beta_n = - beta0 + 0.5 * (datum(0) * datum(0) + mu0.transpose() * Lambda0 * mu0 - - mu_n.transpose() * Lambda_n * mu_n); - // Compute pieces - double prior1 = stan::math::inv_gamma_lpdf(var, alpha0, beta0); - double prior2 = stan::math::multi_normal_prec_lpdf(mean, mu0, Lambda0 / var); - double pr = prior1 + prior2; - double like = hier.get_like_lpdf(datum, cov); - double post1 = stan::math::inv_gamma_lpdf(var, alpha_n, beta_n); - double post2 = - stan::math::multi_normal_prec_lpdf(mean, mu_n, Lambda_n / var); - double post = post1 + post2; +// // Compute posterior parameters +// Eigen::MatrixXd Lambda_n = Lambda0 + cov * cov.transpose(); +// Eigen::VectorXd mu_n = +// stan::math::inverse_spd(Lambda_n) * (datum(0) * cov + Lambda0 * mu0); +// double alpha_n = alpha0 + 0.5; +// double beta_n = +// beta0 + 0.5 * (datum(0) * datum(0) + mu0.transpose() * Lambda0 * mu0 - +// mu_n.transpose() * Lambda_n * mu_n); +// // Compute pieces +// double prior1 = stan::math::inv_gamma_lpdf(var, alpha0, beta0); +// double prior2 = stan::math::multi_normal_prec_lpdf(mean, mu0, Lambda0 / +// var); double pr = prior1 + prior2; double like = hier.get_like_lpdf(datum, +// cov); double post1 = stan::math::inv_gamma_lpdf(var, alpha_n, beta_n); +// double post2 = +// stan::math::multi_normal_prec_lpdf(mean, mu_n, Lambda_n / var); +// double post = post1 + post2; - // Bayes: logmarg(x) = logprior(phi) + loglik(x|phi) - logpost(phi|x) - double sum = pr + like - post; - double marg = hier.prior_pred_lpdf(datum, cov); +// // Bayes: logmarg(x) = logprior(phi) + loglik(x|phi) - logpost(phi|x) +// double sum = pr + like - post; +// double marg = hier.prior_pred_lpdf(datum, cov); - ASSERT_FLOAT_EQ(sum, marg); -} +// ASSERT_FLOAT_EQ(sum, marg); +// } diff --git a/test/write_proto.cc b/test/write_proto.cc index 60b60ba9f..3be320ec9 100644 --- a/test/write_proto.cc +++ b/test/write_proto.cc @@ -3,7 +3,7 @@ #include "algorithm_state.pb.h" #include "ls_state.pb.h" #include "src/hierarchies/nnig_hierarchy.h" -#include "src/hierarchies/nnw_hierarchy.h" +// #include "src/hierarchies/nnw_hierarchy.h" #include "src/utils/proto_utils.h" TEST(set_state, uni_ls) { @@ -47,25 +47,25 @@ TEST(write_proto, uni_ls) { ASSERT_EQ(var, out_var); } -TEST(set_state, multi_ls) { - Eigen::VectorXd mean = Eigen::VectorXd::Ones(5); - Eigen::MatrixXd prec = Eigen::MatrixXd::Identity(5, 5); - prec(1, 1) = 10.0; +// TEST(set_state, multi_ls) { +// Eigen::VectorXd mean = Eigen::VectorXd::Ones(5); +// Eigen::MatrixXd prec = Eigen::MatrixXd::Identity(5, 5); +// prec(1, 1) = 10.0; - bayesmix::MultiLSState curr; - bayesmix::to_proto(mean, curr.mutable_mean()); - bayesmix::to_proto(prec, curr.mutable_prec()); +// bayesmix::MultiLSState curr; +// bayesmix::to_proto(mean, curr.mutable_mean()); +// bayesmix::to_proto(prec, curr.mutable_prec()); - ASSERT_EQ(curr.mean().data(0), 1.0); - ASSERT_EQ(curr.prec().data(0), 1.0); - ASSERT_EQ(curr.prec().data(6), 10.0); +// ASSERT_EQ(curr.mean().data(0), 1.0); +// ASSERT_EQ(curr.prec().data(0), 1.0); +// ASSERT_EQ(curr.prec().data(6), 10.0); - bayesmix::AlgorithmState::ClusterState clusval_in; - clusval_in.mutable_multi_ls_state()->CopyFrom(curr); - NNWHierarchy cluster; - cluster.set_state_from_proto(clusval_in); +// bayesmix::AlgorithmState::ClusterState clusval_in; +// clusval_in.mutable_multi_ls_state()->CopyFrom(curr); +// NNWHierarchy cluster; +// cluster.set_state_from_proto(clusval_in); - ASSERT_EQ(curr.mean().data(0), cluster.get_state().mean(0)); - ASSERT_EQ(curr.prec().data(0), cluster.get_state().prec(0, 0)); - ASSERT_EQ(curr.prec().data(6), cluster.get_state().prec(1, 1)); -} +// ASSERT_EQ(curr.mean().data(0), cluster.get_state().mean(0)); +// ASSERT_EQ(curr.prec().data(0), cluster.get_state().prec(0, 0)); +// ASSERT_EQ(curr.prec().data(6), cluster.get_state().prec(1, 1)); +// } From 858671f40c006ac628d2c77c49ec722aa1c83875 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 20 Jan 2022 22:28:44 +0100 Subject: [PATCH 052/317] Comment-out non implemented hierarchies --- src/hierarchies/load_hierarchies.h | 30 +++++++++++++++--------------- src/includes.h | 6 +++--- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/hierarchies/load_hierarchies.h b/src/hierarchies/load_hierarchies.h index d18f8e8d6..982960f0a 100644 --- a/src/hierarchies/load_hierarchies.h +++ b/src/hierarchies/load_hierarchies.h @@ -6,10 +6,10 @@ #include "abstract_hierarchy.h" #include "hierarchy_id.pb.h" -#include "lin_reg_uni_hierarchy.h" +// #include "lin_reg_uni_hierarchy.h" #include "nnig_hierarchy.h" -#include "nnw_hierarchy.h" -#include "nnxig_hierarchy.h" +// #include "nnw_hierarchy.h" +// #include "nnxig_hierarchy.h" #include "src/runtime/factory.h" //! Loads all available `Hierarchy` objects into the appropriate factory, so @@ -26,20 +26,20 @@ __attribute__((constructor)) static void load_hierarchies() { Builder NNIGbuilder = []() { return std::make_shared(); }; - Builder NNxIGbuilder = []() { - return std::make_shared(); - }; - Builder NNWbuilder = []() { - return std::make_shared(); - }; - Builder LinRegUnibuilder = []() { - return std::make_shared(); - }; + // Builder NNxIGbuilder = []() { + // return std::make_shared(); + // }; + // Builder NNWbuilder = []() { + // return std::make_shared(); + // }; + // Builder LinRegUnibuilder = []() { + // return std::make_shared(); + // }; factory.add_builder(NNIGHierarchy().get_id(), NNIGbuilder); - factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); - factory.add_builder(NNWHierarchy().get_id(), NNWbuilder); - factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); + // factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); + // factory.add_builder(NNWHierarchy().get_id(), NNWbuilder); + // factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); } #endif // BAYESMIX_HIERARCHIES_LOAD_HIERARCHIES_H_ diff --git a/src/includes.h b/src/includes.h index 077683610..39e03bea3 100644 --- a/src/includes.h +++ b/src/includes.h @@ -9,11 +9,11 @@ #include "algorithms/neal8_algorithm.h" #include "collectors/file_collector.h" #include "collectors/memory_collector.h" -#include "hierarchies/lin_reg_uni_hierarchy.h" +// #include "hierarchies/lin_reg_uni_hierarchy.h" #include "hierarchies/load_hierarchies.h" #include "hierarchies/nnig_hierarchy.h" -#include "hierarchies/nnw_hierarchy.h" -#include "hierarchies/nnxig_hierarchy.h" +// #include "hierarchies/nnw_hierarchy.h" +// #include "hierarchies/nnxig_hierarchy.h" #include "mixings/dirichlet_mixing.h" #include "mixings/load_mixings.h" #include "mixings/logit_sb_mixing.h" From 751e1731b9e385b55ff93f07103786b0bad36cef Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 20 Jan 2022 22:29:44 +0100 Subject: [PATCH 053/317] Improved exception handling --- src/hierarchies/likelihoods/abstract_likelihood.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 8a7a8841c..aea01e267 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -21,8 +21,9 @@ class AbstractLikelihood { // IMPLEMENTED in BaseLikelihood virtual std::shared_ptr clone() const = 0; - double lpdf(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const { + double lpdf( + const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { if (is_dependent()) { return compute_lpdf(datum, covariate); } else { @@ -30,13 +31,13 @@ class AbstractLikelihood { } } - // AGGIUNGERE CLUST_LPDF (CHE VALUTA LA LIKELIHOOD CONGIUNTA SU TUTTO IL - // CLUSTER) - virtual Eigen::VectorXd lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const = 0; + // AGGIUNGERE CLUST_LPDF (CHE VALUTA LA LIKELIHOOD CONGIUNTA SU TUTTO IL + // CLUSTER) + virtual bool is_multivariate() const = 0; virtual bool is_dependent() const = 0; From 42fc1a30631c1baa8ab9c14d124c323a84e6a711 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 20 Jan 2022 22:30:19 +0100 Subject: [PATCH 054/317] Improved API (ONGOING) --- src/hierarchies/updaters/nnig_updater.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h index 074f82f3e..f771f7245 100644 --- a/src/hierarchies/updaters/nnig_updater.h +++ b/src/hierarchies/updaters/nnig_updater.h @@ -12,6 +12,7 @@ class NNIGUpdater { ~NNIGUpdater() = default; std::shared_ptr clone() const; + bool is_conjugate() const { return true; }; void draw(UniNormLikelihood& like, NIGPriorModel& prior); void initialize(UniNormLikelihood& like, NIGPriorModel& prior); void compute_posterior_hypers(UniNormLikelihood& like, NIGPriorModel& prior); From be84b35e356a1f2605e7231bc9dfa602f75687ae Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 20 Jan 2022 22:31:25 +0100 Subject: [PATCH 055/317] Improved API --- src/hierarchies/priors/base_prior_model.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index a1c0aa980..61bb680ed 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -27,6 +27,8 @@ class BasePriorModel : public AbstractPriorModel { HyperParams get_hypers() const { return *hypers; } + HyperParams get_posterior_hypers() const { return *post_hypers; } + void set_posterior_hypers(const HyperParams &_post_hypers) { post_hypers = std::make_shared(_post_hypers); }; From 4b35f18b3ca787b18782f131a38ac18ca26e4c64 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 20 Jan 2022 22:32:11 +0100 Subject: [PATCH 056/317] NNIG Hierarchy implemented as composition (BUGS NOT CHECKED) --- src/hierarchies/abstract_hierarchy.h | 36 +++++-- src/hierarchies/base_hierarchy.h | 137 +++++++++++++++++++++++++-- src/hierarchies/nnig_hierarchy.h | 10 +- 3 files changed, 165 insertions(+), 18 deletions(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 9b54c1fa7..fd642753f 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -72,8 +72,12 @@ class AbstractHierarchy { virtual double prior_pred_lpdf( const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { - throw std::runtime_error( - "Cannot call prior_pred_lpdf() from a non-conjugate hierarchy"); + if (is_conjugate()) { + throw std::runtime_error("prior_pred_lpdf() not implemented yet"); + } else { + throw std::runtime_error( + "Cannot call prior_pred_lpdf() from a non-conjugate hierarchy"); + } } //! Evaluates the log-conditional predictive distr. of data in a single point @@ -83,8 +87,13 @@ class AbstractHierarchy { virtual double conditional_pred_lpdf( const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { - throw std::runtime_error( - "Cannot call conditional_pred_lpdf() from a non-conjugate hierarchy"); + if (is_conjugate()) { + throw std::runtime_error("conditional_pred_lpdf() not implemented yet"); + } else { + throw std::runtime_error( + "Cannot call conditional_pred_lpdf() from a non-conjugate " + "hierarchy"); + } } // EVALUATION FUNCTIONS FOR GRIDS OF POINTS @@ -103,8 +112,12 @@ class AbstractHierarchy { virtual Eigen::VectorXd prior_pred_lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const { - throw std::runtime_error( - "Cannot call prior_pred_lpdf_grid() from a non-conjugate hierarchy"); + if (is_conjugate()) { + throw std::runtime_error("prior_pred_lpdf_grid() not yet implemented"); + } else { + throw std::runtime_error( + "Cannot call prior_pred_lpdf_grid() from a non-conjugate hierarchy"); + } } //! Evaluates the log-prior predictive distr. of data in a grid of points @@ -114,9 +127,14 @@ class AbstractHierarchy { virtual Eigen::VectorXd conditional_pred_lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const { - throw std::runtime_error( - "Cannot call conditional_pred_lpdf_grid() from a non-conjugate " - "hierarchy"); + if (is_conjugate()) { + throw std::runtime_error( + "conditional_pred_lpdf_grid() not yet implemented"); + } else { + throw std::runtime_error( + "Cannot call conditional_pred_lpdf_grid() from a non-conjugate " + "hierarchy"); + } } // SAMPLING FUNCTIONS diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 4e4407795..8cd7b64fa 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -38,6 +38,7 @@ class BaseHierarchy : public AbstractHierarchy { std::shared_ptr updater = std::make_shared(); public: + using HyperParams = decltype(prior->get_hypers()); BaseHierarchy() = default; ~BaseHierarchy() = default; @@ -55,12 +56,90 @@ class BaseHierarchy : public AbstractHierarchy { return out; }; + double like_lpdf(const Eigen::RowVectorXd &datum) const override { + return like->lpdf(datum); + } + Eigen::VectorXd like_lpdf_grid(const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const override { return like->lpdf_grid(data, covariates); }; + double get_marg_lpdf( + const HyperParams ¶ms, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const { + if (this->is_dependent()) { + return marg_lpdf(params, datum, covariate); + } else { + return marg_lpdf(params, datum); + } + } + + double prior_pred_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = + Eigen::RowVectorXd(0)) const override { + return get_marg_lpdf(prior->get_hypers(), datum, covariate); + } + + Eigen::VectorXd prior_pred_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { + Eigen::VectorXd lpdf(data.rows()); + if (covariates.cols() == 0) { + // Pass null value as covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->prior_pred_lpdf( + data.row(i), Eigen::RowVectorXd(0)); + } + } else if (covariates.rows() == 1) { + // Use unique covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->prior_pred_lpdf( + data.row(i), covariates.row(0)); + } + } else { + // Use different covariates + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->prior_pred_lpdf( + data.row(i), covariates.row(i)); + } + } + return lpdf; + } + + double conditional_pred_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate = + Eigen::RowVectorXd(0)) const override { + return get_marg_lpdf(prior->get_posterior_hypers(), datum, covariate); + } + + Eigen::VectorXd conditional_pred_lpdf_grid( + const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { + Eigen::VectorXd lpdf(data.rows()); + if (covariates.cols() == 0) { + // Pass null value as covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->conditional_pred_lpdf( + data.row(i), Eigen::RowVectorXd(0)); + } + } else if (covariates.rows() == 1) { + // Use unique covariate + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->conditional_pred_lpdf( + data.row(i), covariates.row(0)); + } + } else { + // Use different covariates + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->conditional_pred_lpdf( + data.row(i), covariates.row(i)); + } + } + return lpdf; + } + void sample_prior() override { like->set_state_from_proto(*prior->sample(false)); }; @@ -69,11 +148,31 @@ class BaseHierarchy : public AbstractHierarchy { updater->draw(*like, *prior); }; - // DA IMPLEMENTARE !!! void sample_full_cond( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override { - return; + like->clear_data(); + like->clear_summary_statistics(); + if (covariates.cols() == 0) { + // Pass null value as covariate + for (int i = 0; i < data.rows(); i++) { + static_cast(this)->add_datum(i, data.row(i), false, + Eigen::RowVectorXd(0)); + } + } else if (covariates.rows() == 1) { + // Use unique covariate + for (int i = 0; i < data.rows(); i++) { + static_cast(this)->add_datum(i, data.row(i), false, + covariates.row(0)); + } + } else { + // Use different covariates + for (int i = 0; i < data.rows(); i++) { + static_cast(this)->add_datum(i, data.row(i), false, + covariates.row(i)); + } + } + static_cast(this)->sample_full_cond(true); }; void update_hypers(const std::vector @@ -81,6 +180,10 @@ class BaseHierarchy : public AbstractHierarchy { prior->update_hypers(states); }; + auto get_state() const -> decltype(like->get_state()) { + return like->get_state(); + }; + int get_card() const override { return like->get_card(); }; double get_log_card() const override { return like->get_log_card(); }; @@ -108,22 +211,20 @@ class BaseHierarchy : public AbstractHierarchy { prior->set_hypers_from_proto(state_); }; - // DA SISTEMARE void add_datum( const int id, const Eigen::RowVectorXd &datum, const bool update_params = false, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override { - // gestire update_params !! like->add_datum(id, datum, covariate); + if (update_params) updater->compute_posterior_hypers(*like, *prior); }; - // DA SISTEMARE void remove_datum( const int id, const Eigen::RowVectorXd &datum, const bool update_params = false, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override { - // gestire update_params !! like->remove_datum(id, datum, covariate); + if (update_params) updater->compute_posterior_hypers(*like, *prior); }; void initialize() override { updater->initialize(*like, *prior); }; @@ -131,6 +232,30 @@ class BaseHierarchy : public AbstractHierarchy { bool is_multivariate() const override { return like->is_multivariate(); }; bool is_dependent() const override { return like->is_dependent(); }; + + bool is_conjugate() const override { return updater->is_conjugate(); }; + + protected: + virtual double marg_lpdf(const HyperParams ¶ms, + const Eigen::RowVectorXd &datum) const { + if (!is_conjugate()) { + throw std::runtime_error( + "Call marg_lpdf() for a non-conjugate hierarchy"); + } else { + throw std::runtime_error("marg_lpdf() not yet implemented"); + } + } + + virtual double marg_lpdf(const HyperParams ¶ms, + const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const { + if (!is_conjugate()) { + throw std::runtime_error( + "Call marg_lpdf() for a non-conjugate hierarchy"); + } else { + throw std::runtime_error("marg_lpdf() not yet implemented"); + } + } }; // //! Returns an independent, data-less copy of this object diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index 1271159fe..c2e7aa0ab 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -27,9 +27,13 @@ class NNIGHierarchy : public BaseHierarchy Date: Fri, 21 Jan 2022 09:22:00 +0100 Subject: [PATCH 057/317] Minor changes --- src/hierarchies/likelihoods/uni_norm_likelihood.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc index ac2539098..9b967e4bb 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -25,11 +25,9 @@ void UniNormLikelihood::set_state_from_proto( std::shared_ptr UniNormLikelihood::get_state_proto() const { - bayesmix::UniLSState state_; - state_.set_mean(state.mean); - state_.set_var(state.var); auto out = std::make_shared(); - out->mutable_uni_ls_state()->CopyFrom(state_); + out->mutable_uni_ls_state()->set_mean(state.mean); + out->mutable_uni_ls_state()->set_var(state.var); return out; } From 7820fab19b56d8d0ba6e4278805c47cb7f5b08bf Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 21 Jan 2022 21:15:23 +0100 Subject: [PATCH 058/317] Minor changes --- src/hierarchies/updaters/nnig_updater.cc | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc index cf4c0852f..23e4063aa 100644 --- a/src/hierarchies/updaters/nnig_updater.cc +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -25,15 +25,18 @@ void NNIGUpdater::initialize(UniNormLikelihood &like, NIGPriorModel &prior) { void NNIGUpdater::compute_posterior_hypers(UniNormLikelihood &like, NIGPriorModel &prior) { + // std::cout << "NNIGUpdater::compute_posterior_hypers()" << std::endl; // Getting required quantities from likelihood and prior int card = like.get_card(); double data_sum = like.get_data_sum(); double data_sum_squares = like.get_data_sum_squares(); - auto hypers = std::make_shared(prior.get_hypers()); + auto hypers = prior.get_hypers(); + + // std::cout << "current cardinality: " << card << std::endl; // No update possible if (card == 0) { - prior.set_posterior_hypers(*hypers); + prior.set_posterior_hypers(hypers); return; } @@ -41,14 +44,13 @@ void NNIGUpdater::compute_posterior_hypers(UniNormLikelihood &like, Hyperparams::NIG post_params; double y_bar = data_sum / (1.0 * card); // sample mean double ss = data_sum_squares - card * y_bar * y_bar; - post_params.mean = (hypers->var_scaling * hypers->mean + data_sum) / - (hypers->var_scaling + card); - post_params.var_scaling = hypers->var_scaling + card; - post_params.shape = hypers->shape + 0.5 * card; - post_params.scale = hypers->scale + 0.5 * ss + - 0.5 * hypers->var_scaling * card * - (y_bar - hypers->mean) * (y_bar - hypers->mean) / - (card + hypers->var_scaling); + post_params.mean = (hypers.var_scaling * hypers.mean + data_sum) / + (hypers.var_scaling + card); + post_params.var_scaling = hypers.var_scaling + card; + post_params.shape = hypers.shape + 0.5 * card; + post_params.scale = hypers.scale + 0.5 * ss + + 0.5 * hypers.var_scaling * card * (y_bar - hypers.mean) * + (y_bar - hypers.mean) / (card + hypers.var_scaling); prior.set_posterior_hypers(post_params); return; From 521cf8f1612e7856567de4d3abc60f0910e55fb2 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 14:31:18 +0100 Subject: [PATCH 059/317] Minor code fix --- src/hierarchies/likelihoods/base_likelihood.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 90a32b584..a68f7c823 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -74,9 +74,9 @@ class BaseLikelihood : public AbstractLikelihood { State state; - int card; + int card = 0; - int log_card; + double log_card = stan::math::NEGATIVE_INFTY; std::set cluster_data_idx; }; From 20ce2ac518a49a26649d75c4971af82594677184 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 16:23:18 +0100 Subject: [PATCH 060/317] Use hypers by copy (maybe not ideal) --- src/hierarchies/priors/nig_prior_model.cc | 76 +++++++++++------------ 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index 6f9b6cbb5..7df752685 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -3,34 +3,34 @@ void NIGPriorModel::initialize_hypers() { if (prior->has_fixed_values()) { // Set values - hypers->mean = prior->fixed_values().mean(); - hypers->var_scaling = prior->fixed_values().var_scaling(); - hypers->shape = prior->fixed_values().shape(); - hypers->scale = prior->fixed_values().scale(); + hypers.mean = prior->fixed_values().mean(); + hypers.var_scaling = prior->fixed_values().var_scaling(); + hypers.shape = prior->fixed_values().shape(); + hypers.scale = prior->fixed_values().scale(); // Check validity - if (hypers->var_scaling <= 0) { + if (hypers.var_scaling <= 0) { throw std::invalid_argument("Variance-scaling parameter must be > 0"); } - if (hypers->shape <= 0) { + if (hypers.shape <= 0) { throw std::invalid_argument("Shape parameter must be > 0"); } - if (hypers->scale <= 0) { + if (hypers.scale <= 0) { throw std::invalid_argument("scale parameter must be > 0"); } } else if (prior->has_normal_mean_prior()) { // Set initial values - hypers->mean = prior->normal_mean_prior().mean_prior().mean(); - hypers->var_scaling = prior->normal_mean_prior().var_scaling(); - hypers->shape = prior->normal_mean_prior().shape(); - hypers->scale = prior->normal_mean_prior().scale(); + hypers.mean = prior->normal_mean_prior().mean_prior().mean(); + hypers.var_scaling = prior->normal_mean_prior().var_scaling(); + hypers.shape = prior->normal_mean_prior().shape(); + hypers.scale = prior->normal_mean_prior().scale(); // Check validity - if (hypers->var_scaling <= 0) { + if (hypers.var_scaling <= 0) { throw std::invalid_argument("Variance-scaling parameter must be > 0"); } - if (hypers->shape <= 0) { + if (hypers.shape <= 0) { throw std::invalid_argument("Shape parameter must be > 0"); } - if (hypers->scale <= 0) { + if (hypers.scale <= 0) { throw std::invalid_argument("scale parameter must be > 0"); } } else if (prior->has_ngg_prior()) { @@ -66,10 +66,10 @@ void NIGPriorModel::initialize_hypers() { throw std::invalid_argument("Shape parameter must be > 0"); } // Set initial values - hypers->mean = mu00; - hypers->var_scaling = alpha00 / beta00; - hypers->shape = alpha0; - hypers->scale = a00 / b00; + hypers.mean = mu00; + hypers.var_scaling = alpha00 / beta00; + hypers.shape = alpha0; + hypers.scale = a00 / b00; } else { throw std::invalid_argument("Unrecognized hierarchy prior"); } @@ -78,16 +78,16 @@ void NIGPriorModel::initialize_hypers() { double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { auto &state = downcast_state(state_).uni_ls_state(); double target = - stan::math::normal_lpdf(state.mean(), hypers->mean, - sqrt(state.var() / hypers->var_scaling)) + - stan::math::inv_gamma_lpdf(state.var(), hypers->shape, hypers->scale); + stan::math::normal_lpdf(state.mean(), hypers.mean, + sqrt(state.var() / hypers.var_scaling)) + + stan::math::inv_gamma_lpdf(state.var(), hypers.shape, hypers.scale); return target; } std::shared_ptr NIGPriorModel::sample( bool use_post_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - Hyperparams::NIG params = use_post_hypers ? *post_hypers : *hypers; + Hyperparams::NIG params = use_post_hypers ? post_hypers : hypers; double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); double mean = @@ -124,7 +124,7 @@ void NIGPriorModel::update_hypers( double mu_n = num / prec; double sig2_n = 1 / prec; // Update hyperparameters with posterior random sampling - hypers->mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); + hypers.mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); } else if (prior->has_ngg_prior()) { // Get hyperparameters: // for mu0 @@ -145,20 +145,20 @@ void NIGPriorModel::update_hypers( double var = st.uni_ls_state().var(); b_n += 1 / var; num += mean / var; - beta_n += (hypers->mean - mean) * (hypers->mean - mean) / var; + beta_n += (hypers.mean - mean) * (hypers.mean - mean) / var; } - double var = hypers->var_scaling * b_n + 1 / sig200; + double var = hypers.var_scaling * b_n + 1 / sig200; b_n += b00; - num = hypers->var_scaling * num + mu00 / sig200; + num = hypers.var_scaling * num + mu00 / sig200; beta_n = beta00 + 0.5 * beta_n; double sig_n = 1 / var; double mu_n = num / var; double alpha_n = alpha00 + 0.5 * states.size(); - double a_n = a00 + states.size() * hypers->shape; + double a_n = a00 + states.size() * hypers.shape; // Update hyperparameters with posterior random Gibbs sampling - hypers->mean = stan::math::normal_rng(mu_n, sig_n, rng); - hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); - hypers->scale = stan::math::gamma_rng(a_n, b_n, rng); + hypers.mean = stan::math::normal_rng(mu_n, sig_n, rng); + hypers.var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); + hypers.scale = stan::math::gamma_rng(a_n, b_n, rng); } else { throw std::invalid_argument("Unrecognized hierarchy prior"); } @@ -167,19 +167,19 @@ void NIGPriorModel::update_hypers( void NIGPriorModel::set_hypers_from_proto( const google::protobuf::Message &hypers_) { auto &hyperscast = downcast_hypers(hypers_).nnig_state(); - hypers->mean = hyperscast.mean(); - hypers->var_scaling = hyperscast.var_scaling(); - hypers->scale = hyperscast.scale(); - hypers->shape = hyperscast.shape(); + hypers.mean = hyperscast.mean(); + hypers.var_scaling = hyperscast.var_scaling(); + hypers.scale = hyperscast.scale(); + hypers.shape = hyperscast.shape(); } std::shared_ptr NIGPriorModel::get_hypers_proto() const { bayesmix::NIGDistribution hypers_; - hypers_.set_mean(hypers->mean); - hypers_.set_var_scaling(hypers->var_scaling); - hypers_.set_shape(hypers->shape); - hypers_.set_scale(hypers->scale); + hypers_.set_mean(hypers.mean); + hypers_.set_var_scaling(hypers.var_scaling); + hypers_.set_shape(hypers.shape); + hypers_.set_scale(hypers.scale); auto out = std::make_shared(); out->mutable_nnig_state()->CopyFrom(hypers_); From 04c442a35a9705e5228c3409b6e95492b198737b Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 16:24:07 +0100 Subject: [PATCH 061/317] Major bug temp fix --- src/hierarchies/base_hierarchy.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 8cd7b64fa..be18113b8 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -141,11 +141,15 @@ class BaseHierarchy : public AbstractHierarchy { } void sample_prior() override { + int card = like->get_card(); like->set_state_from_proto(*prior->sample(false)); + like->set_card(card); }; void sample_full_cond(bool update_params = false) override { - updater->draw(*like, *prior); + int card = like->get_card(); + updater->draw(*like, *prior, update_params); + like->set_card(card); }; void sample_full_cond( From 5a22adfcde139184df9fc615730bc84f7164aac6 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 16:24:40 +0100 Subject: [PATCH 062/317] Fixed bug --- src/hierarchies/updaters/nnig_updater.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h index f771f7245..09fc5a1e9 100644 --- a/src/hierarchies/updaters/nnig_updater.h +++ b/src/hierarchies/updaters/nnig_updater.h @@ -13,9 +13,9 @@ class NNIGUpdater { std::shared_ptr clone() const; bool is_conjugate() const { return true; }; - void draw(UniNormLikelihood& like, NIGPriorModel& prior); + void draw(UniNormLikelihood& like, NIGPriorModel& prior, bool update_params); void initialize(UniNormLikelihood& like, NIGPriorModel& prior); void compute_posterior_hypers(UniNormLikelihood& like, NIGPriorModel& prior); }; -#endif // BAYESMIX_HIERARCHIES_NNIG_UPDATERS_H_ +#endif // BAYESMIX_HIERARCHIES_NNIG_UPDATER_H_ From f55135bb496c259250ffebcc3a3858d7c5a316ba Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 16:26:08 +0100 Subject: [PATCH 063/317] Fixed bug --- src/hierarchies/updaters/nnig_updater.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc index 23e4063aa..1a4fa7606 100644 --- a/src/hierarchies/updaters/nnig_updater.cc +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -56,11 +56,14 @@ void NNIGUpdater::compute_posterior_hypers(UniNormLikelihood &like, return; }; -void NNIGUpdater::draw(UniNormLikelihood &like, NIGPriorModel &prior) { +void NNIGUpdater::draw(UniNormLikelihood &like, NIGPriorModel &prior, + bool update_params) { if (like.get_card() == 0) { - like.set_state_from_proto(*prior.sample(true)); + like.set_state_from_proto(*prior.sample(false)); } else { - compute_posterior_hypers(like, prior); + if (update_params) { + compute_posterior_hypers(like, prior); + } like.set_state_from_proto(*prior.sample(true)); } }; From adc967a3839e69241496a87555c5c7f3f524c2cc Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 16:26:50 +0100 Subject: [PATCH 064/317] Make set_card() public --- src/hierarchies/likelihoods/base_likelihood.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index a68f7c823..8e3fa80aa 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -31,6 +31,11 @@ class BaseLikelihood : public AbstractLikelihood { int get_card() const { return card; } + void set_card(const int card_) { + card = card_; + log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); + } + double get_log_card() const { return log_card; } std::set get_data_idx() const { return cluster_data_idx; } @@ -55,11 +60,6 @@ class BaseLikelihood : public AbstractLikelihood { } protected: - void set_card(const int card_) { - card = card_; - log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); - } - bayesmix::AlgorithmState::ClusterState *downcast_state( google::protobuf::Message *state_) const { return google::protobuf::internal::down_cast< From d6d22c9dc25bf6e3c0aaf2e6a991d237f164de16 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 16:27:33 +0100 Subject: [PATCH 065/317] Use params by copy (maybe go back) --- src/hierarchies/priors/base_prior_model.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 61bb680ed..afc21a0d7 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -25,12 +25,12 @@ class BasePriorModel : public AbstractPriorModel { virtual google::protobuf::Message *get_mutable_prior() override; - HyperParams get_hypers() const { return *hypers; } + HyperParams get_hypers() const { return hypers; } - HyperParams get_posterior_hypers() const { return *post_hypers; } + HyperParams get_posterior_hypers() const { return post_hypers; } void set_posterior_hypers(const HyperParams &_post_hypers) { - post_hypers = std::make_shared(_post_hypers); + post_hypers = _post_hypers; }; void write_hypers_to_proto(google::protobuf::Message *out) const override; @@ -60,8 +60,8 @@ class BasePriorModel : public AbstractPriorModel { const bayesmix::AlgorithmState::ClusterState &>(state_); } - std::shared_ptr hypers = std::make_shared(); - std::shared_ptr post_hypers = std::make_shared(); + HyperParams hypers; + HyperParams post_hypers; std::shared_ptr prior; }; From 38c73fe63ef8a796f1ee4e149dcd4959e17d933b Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 16:27:51 +0100 Subject: [PATCH 066/317] Minor code changes --- src/hierarchies/priors/abstract_prior_model.h | 2 +- src/hierarchies/priors/nig_prior_model.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index f3d4e2702..5ccc7ddfc 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -22,7 +22,7 @@ class AbstractPriorModel { // Da pensare, come restituisco lo stato? magari un pointer? Oppure delego virtual std::shared_ptr sample( - bool use_post_hypers = false) = 0; + bool use_post_hypers) = 0; virtual void update_hypers( const std::vector &states) = 0; diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index bf92a6825..1ffda2007 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -23,7 +23,7 @@ class NIGPriorModel : public BasePriorModel sample( - bool use_post_hypers = false) override; + bool use_post_hypers) override; void update_hypers(const std::vector &states) override; From 63537ad163157cab17ba7b4e2d43af64be21bb24 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 16:39:52 +0100 Subject: [PATCH 067/317] Make notebook more general --- python/notebooks/gaussian_mix_uni.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/notebooks/gaussian_mix_uni.ipynb b/python/notebooks/gaussian_mix_uni.ipynb index 8d368c487..d5a01faf3 100644 --- a/python/notebooks/gaussian_mix_uni.ipynb +++ b/python/notebooks/gaussian_mix_uni.ipynb @@ -19,7 +19,7 @@ "outputs": [], "source": [ "import os\n", - "os.environ[\"BAYESMIX_EXE\"] = \"/Users/marioberaha/dev/bayesmix_origin/build/run_mcmc\"" + "os.environ[\"BAYESMIX_EXE\"] = '../../build/run_mcmc'" ] }, { @@ -294,7 +294,7 @@ "metadata": {}, "outputs": [], "source": [ - "numcluschain, cluschain, bestclus = run_mcmc(\n", + "_ , numcluschain, cluschain, bestclus = run_mcmc(\n", " \"NNIG\", \"DP\", data, g0_params_allprior, dp_params_prior, neal2_algo, \n", " dens_grid=None, return_clusters=True, return_num_clusters=True,\n", " return_best_clus=True)" From 7d4a1241c9f88eef55922b5912213e981ab2bd20 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 19:53:36 +0100 Subject: [PATCH 068/317] Change random seed in test --- test/semi_hdp.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/test/semi_hdp.cc b/test/semi_hdp.cc index b193bea14..830933d9c 100644 --- a/test/semi_hdp.cc +++ b/test/semi_hdp.cc @@ -302,6 +302,7 @@ TEST(semihdp, sample_unique_values2) { } TEST(semihdp, sample_allocations1) { + bayesmix::Rng::Instance().seed(220122); std::vector data(2); data[0] = bayesmix::vstack({Eigen::MatrixXd::Zero(50, 1).array() + 5, From 420e6045ae56fb8e1de632a03f0c1bb80114dcb6 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 19:54:41 +0100 Subject: [PATCH 069/317] Allow only state update from proto --- src/hierarchies/base_hierarchy.h | 10 +++++----- src/hierarchies/updaters/nnig_updater.cc | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index be18113b8..7c921dc94 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -141,15 +141,15 @@ class BaseHierarchy : public AbstractHierarchy { } void sample_prior() override { - int card = like->get_card(); - like->set_state_from_proto(*prior->sample(false)); - like->set_card(card); + // int card = like->get_card(); + like->set_state_from_proto(*prior->sample(false), false); + // like->set_card(card); }; void sample_full_cond(bool update_params = false) override { - int card = like->get_card(); + // int card = like->get_card(); updater->draw(*like, *prior, update_params); - like->set_card(card); + // like->set_card(card); }; void sample_full_cond( diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc index 1a4fa7606..be0f82d9c 100644 --- a/src/hierarchies/updaters/nnig_updater.cc +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -59,11 +59,11 @@ void NNIGUpdater::compute_posterior_hypers(UniNormLikelihood &like, void NNIGUpdater::draw(UniNormLikelihood &like, NIGPriorModel &prior, bool update_params) { if (like.get_card() == 0) { - like.set_state_from_proto(*prior.sample(false)); + like.set_state_from_proto(*prior.sample(false), false); } else { if (update_params) { compute_posterior_hypers(like, prior); } - like.set_state_from_proto(*prior.sample(true)); + like.set_state_from_proto(*prior.sample(true), false); } }; From de26535e5e65b9a1d033abf1c0a6f39fd9cff0cb Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 19:56:00 +0100 Subject: [PATCH 070/317] Allow only state update from proto --- src/hierarchies/likelihoods/uni_norm_likelihood.cc | 4 ++-- src/hierarchies/likelihoods/uni_norm_likelihood.h | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc index 9b967e4bb..b0ceada1f 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -16,11 +16,11 @@ void UniNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, } void UniNormLikelihood::set_state_from_proto( - const google::protobuf::Message &state_) { + const google::protobuf::Message &state_, bool update_card) { auto &statecast = downcast_state(state_); state.mean = statecast.uni_ls_state().mean(); state.var = statecast.uni_ls_state().var(); - set_card(statecast.cardinality()); + if (update_card) set_card(statecast.cardinality()); } std::shared_ptr diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index e36f9d908..133e1f81e 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -19,7 +19,8 @@ class UniNormLikelihood ~UniNormLikelihood() = default; bool is_multivariate() const override { return false; }; bool is_dependent() const override { return false; }; - void set_state_from_proto(const google::protobuf::Message &state_) override; + void set_state_from_proto(const google::protobuf::Message &state_, + bool update_card = true) override; void clear_summary_statistics() override; double get_data_sum() const { return data_sum; }; double get_data_sum_squares() const { return data_sum_squares; }; From 2ed181617eb6eceffefaf69dc1b64ffe9cd51551 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 19:56:38 +0100 Subject: [PATCH 071/317] Allow only update of state from proto --- src/hierarchies/likelihoods/abstract_likelihood.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index aea01e267..9beef7729 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -42,8 +42,8 @@ class AbstractLikelihood { virtual bool is_dependent() const = 0; - virtual void set_state_from_proto( - const google::protobuf::Message &state_) = 0; + virtual void set_state_from_proto(const google::protobuf::Message &state_, + bool update_card = true) = 0; // IMPLEMENTED in BaseLikelihood virtual void write_state_to_proto(google::protobuf::Message *out) const = 0; From 212ae5d2649d255785d8806e5ad0ee812ae81bcc Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 20:55:00 +0100 Subject: [PATCH 072/317] Make set_card protected --- src/hierarchies/likelihoods/base_likelihood.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 8e3fa80aa..a68f7c823 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -31,11 +31,6 @@ class BaseLikelihood : public AbstractLikelihood { int get_card() const { return card; } - void set_card(const int card_) { - card = card_; - log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); - } - double get_log_card() const { return log_card; } std::set get_data_idx() const { return cluster_data_idx; } @@ -60,6 +55,11 @@ class BaseLikelihood : public AbstractLikelihood { } protected: + void set_card(const int card_) { + card = card_; + log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); + } + bayesmix::AlgorithmState::ClusterState *downcast_state( google::protobuf::Message *state_) const { return google::protobuf::internal::down_cast< From 3c90a0934a70a36999530751c0639ae3c26f6aa2 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 22 Jan 2022 20:57:32 +0100 Subject: [PATCH 073/317] Fixed typo --- python/notebooks/gaussian_mix_uni.ipynb | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/notebooks/gaussian_mix_uni.ipynb b/python/notebooks/gaussian_mix_uni.ipynb index d5a01faf3..40e344f9c 100644 --- a/python/notebooks/gaussian_mix_uni.ipynb +++ b/python/notebooks/gaussian_mix_uni.ipynb @@ -226,7 +226,7 @@ "\n", "`return_clusters=False, return_num_clusters=False, return_best_clus=False`\n", "\n", - "Observe that the number of iterations is extremely small! In real problems, you might want to set the burnin at least to 1000 iterations and the total number of iterations to at leas 2000." + "Observe that the number of iterations is extremely small! In real problems, you might want to set the burnin at least to 1000 iterations and the total number of iterations to at least 2000." ] }, { @@ -325,6 +325,13 @@ " plt.scatter(data_in_clus, np.zeros_like(data_in_clus) + 0.01)\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -332,7 +339,7 @@ "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -346,7 +353,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.8.10" } }, "nbformat": 4, From 18e8cd078b25f9f4334d2778e6b445912f2e8b78 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 25 Jan 2022 13:58:55 +0100 Subject: [PATCH 074/317] Updaters Hierarichy classes (ONGOING) --- src/hierarchies/updaters/CMakeLists.txt | 2 ++ src/hierarchies/updaters/abstract_updater.h | 23 ++++++++++++++ src/hierarchies/updaters/conjugate_updater.h | 32 ++++++++++++++++++++ 3 files changed, 57 insertions(+) create mode 100644 src/hierarchies/updaters/abstract_updater.h create mode 100644 src/hierarchies/updaters/conjugate_updater.h diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index bb6da2924..616ef93b4 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -1,5 +1,7 @@ target_sources(bayesmix PUBLIC + # abstract_updater.h + # conjugate_updater.h nnig_updater.h nnig_updater.cc ) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h new file mode 100644 index 000000000..7815dbde1 --- /dev/null +++ b/src/hierarchies/updaters/abstract_updater.h @@ -0,0 +1,23 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ + +// NOT WORKING AT THE MOMENT + +#include + +#include "src/hierarchies/likelihoods/abstract_likelihood.h" +#include "src/hierarchies/priors/abstract_prior_model.h" + +class AbstractUpdater { + public: + virtual ~AbstractUpdater() = default; + // virtual std::shared_ptr clone() const = 0; NON CREDO CI + // SERVA + bool is_conjugate() const { return false; }; + virtual void initialize(AbstractLikelihood &like, + AbstractPriorModel &prior) = 0; + virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, + bool update_params) = 0; +}; + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ diff --git a/src/hierarchies/updaters/conjugate_updater.h b/src/hierarchies/updaters/conjugate_updater.h new file mode 100644 index 000000000..e39f94efd --- /dev/null +++ b/src/hierarchies/updaters/conjugate_updater.h @@ -0,0 +1,32 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_CONJUGATE_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_CONJUGATE_UPDATER_H_ + +// NOT WORKING AT THE MOMENT + +#include "abstract_updater.h" + +class ConjugateUpdater : public AbstractUpdater { + public: + ConjugateUpdater() = default; + ~ConjugateUpdater() = default; + bool is_conjugate() const override { return true; }; + void draw(AbstractLikelihood &like, AbstractPriorModel &prior, + bool update_params) override; + virtual void compute_posterior_hypers(UniNormLikelihood &like, + NIGPriorModel &prior) = 0; +}; + +void ConjugateUpdater::draw(AbstractLikelihood &like, + AbstractPriorModel &prior, bool update_params) { + bool set_card = true; + if (like.get_card() == 0) { + like.set_state_from_proto(*prior.sample(false), !set_card); + } else { + if (update_params) { + compute_posterior_hypers(like, prior); + } + like.set_state_from_proto(*prior.sample(true), !set_card); + } +} + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_CONJUGATE_UPDATER_H_ From 77696f1cfba8e1c2745232538ac70fdc6c7c63cb Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 25 Jan 2022 17:36:11 +0100 Subject: [PATCH 075/317] Add source files to target bayesmix --- src/hierarchies/updaters/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index 616ef93b4..0b4135292 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -1,7 +1,7 @@ target_sources(bayesmix PUBLIC - # abstract_updater.h - # conjugate_updater.h + abstract_updater.h + conjugate_updater.h nnig_updater.h nnig_updater.cc ) From 92577ec6acd8f249b9509f945d890a9632a32f4b Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 25 Jan 2022 17:37:13 +0100 Subject: [PATCH 076/317] Now NNIGUpdater derived from ConjugateUpdater --- src/hierarchies/updaters/nnig_updater.cc | 61 ++++++++++-------------- src/hierarchies/updaters/nnig_updater.h | 20 ++++---- 2 files changed, 35 insertions(+), 46 deletions(-) diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc index be0f82d9c..c21a2c115 100644 --- a/src/hierarchies/updaters/nnig_updater.cc +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -1,16 +1,18 @@ #include "nnig_updater.h" -std::shared_ptr NNIGUpdater::clone() const { - auto out = - std::make_shared(static_cast(*this)); - return out; -}; +#include "src/hierarchies/likelihoods/states.h" +#include "src/hierarchies/priors/hyperparams.h" + +void NNIGUpdater::initialize(AbstractLikelihood& like, + AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); -void NNIGUpdater::initialize(UniNormLikelihood &like, NIGPriorModel &prior) { // PriorModel Initialization - prior.initialize(); - Hyperparams::NIG hypers = prior.get_hypers(); - prior.set_posterior_hypers(hypers); + priorcast.initialize(); + auto hypers = priorcast.get_hypers(); + priorcast.set_posterior_hypers(hypers); // State initialization State::UniLS state; @@ -18,25 +20,26 @@ void NNIGUpdater::initialize(UniNormLikelihood &like, NIGPriorModel &prior) { state.var = hypers.scale / (hypers.shape + 1); // Likelihood Initalization - like.set_state(state); - like.clear_data(); - like.clear_summary_statistics(); + likecast.set_state(state); + likecast.clear_data(); + likecast.clear_summary_statistics(); }; -void NNIGUpdater::compute_posterior_hypers(UniNormLikelihood &like, - NIGPriorModel &prior) { - // std::cout << "NNIGUpdater::compute_posterior_hypers()" << std::endl; - // Getting required quantities from likelihood and prior - int card = like.get_card(); - double data_sum = like.get_data_sum(); - double data_sum_squares = like.get_data_sum_squares(); - auto hypers = prior.get_hypers(); +void NNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); - // std::cout << "current cardinality: " << card << std::endl; + // Getting required quantities from likelihood and prior + int card = likecast.get_card(); + double data_sum = likecast.get_data_sum(); + double data_sum_squares = likecast.get_data_sum_squares(); + auto hypers = priorcast.get_hypers(); // No update possible if (card == 0) { - prior.set_posterior_hypers(hypers); + priorcast.set_posterior_hypers(hypers); return; } @@ -52,18 +55,6 @@ void NNIGUpdater::compute_posterior_hypers(UniNormLikelihood &like, 0.5 * hypers.var_scaling * card * (y_bar - hypers.mean) * (y_bar - hypers.mean) / (card + hypers.var_scaling); - prior.set_posterior_hypers(post_params); + priorcast.set_posterior_hypers(post_params); return; }; - -void NNIGUpdater::draw(UniNormLikelihood &like, NIGPriorModel &prior, - bool update_params) { - if (like.get_card() == 0) { - like.set_state_from_proto(*prior.sample(false), false); - } else { - if (update_params) { - compute_posterior_hypers(like, prior); - } - like.set_state_from_proto(*prior.sample(true), false); - } -}; diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h index 09fc5a1e9..60e4de83b 100644 --- a/src/hierarchies/updaters/nnig_updater.h +++ b/src/hierarchies/updaters/nnig_updater.h @@ -1,21 +1,19 @@ -#ifndef BAYESMIX_HIERARCHIES_NNIG_UPDATER_H_ -#define BAYESMIX_HIERARCHIES_NNIG_UPDATER_H_ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ -#include "src/hierarchies/likelihoods/states.h" +#include "conjugate_updater.h" #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" -#include "src/hierarchies/priors/hyperparams.h" #include "src/hierarchies/priors/nig_prior_model.h" -class NNIGUpdater { +class NNIGUpdater : public ConjugateUpdater { public: NNIGUpdater() = default; ~NNIGUpdater() = default; - std::shared_ptr clone() const; - bool is_conjugate() const { return true; }; - void draw(UniNormLikelihood& like, NIGPriorModel& prior, bool update_params); - void initialize(UniNormLikelihood& like, NIGPriorModel& prior); - void compute_posterior_hypers(UniNormLikelihood& like, NIGPriorModel& prior); + void initialize(AbstractLikelihood& like, + AbstractPriorModel& prior) override; + void compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) override; }; -#endif // BAYESMIX_HIERARCHIES_NNIG_UPDATER_H_ +#endif // BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ From dd90b8c7f363606ecedff0897a0826eefa34758e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 25 Jan 2022 17:40:12 +0100 Subject: [PATCH 077/317] Minor code changes --- src/hierarchies/base_hierarchy.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 7c921dc94..55d1578c2 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -44,7 +44,6 @@ class BaseHierarchy : public AbstractHierarchy { void set_likelihood(std::shared_ptr like_) { like = like_; }; void set_prior(std::shared_ptr prior_) { prior = prior_; }; - void set_updater(std::shared_ptr updater_) { updater = updater_; }; std::shared_ptr clone() const override { // Create copy of the hierarchy @@ -52,7 +51,6 @@ class BaseHierarchy : public AbstractHierarchy { // Cloning each component class out->set_likelihood(std::static_pointer_cast(like->clone())); out->set_prior(std::static_pointer_cast(prior->clone())); - out->set_updater(std::static_pointer_cast(updater->clone())); return out; }; From 2098a423ed7f94c8aa311eb494479271147fdaa8 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 25 Jan 2022 17:40:45 +0100 Subject: [PATCH 078/317] Created inheritance structure for updaters --- src/hierarchies/updaters/abstract_updater.h | 8 +--- src/hierarchies/updaters/conjugate_updater.h | 48 +++++++++++++++----- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 7815dbde1..d89778441 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -1,19 +1,13 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ -// NOT WORKING AT THE MOMENT - -#include - #include "src/hierarchies/likelihoods/abstract_likelihood.h" #include "src/hierarchies/priors/abstract_prior_model.h" class AbstractUpdater { public: virtual ~AbstractUpdater() = default; - // virtual std::shared_ptr clone() const = 0; NON CREDO CI - // SERVA - bool is_conjugate() const { return false; }; + virtual bool is_conjugate() const { return false; }; virtual void initialize(AbstractLikelihood &like, AbstractPriorModel &prior) = 0; virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, diff --git a/src/hierarchies/updaters/conjugate_updater.h b/src/hierarchies/updaters/conjugate_updater.h index e39f94efd..99552686e 100644 --- a/src/hierarchies/updaters/conjugate_updater.h +++ b/src/hierarchies/updaters/conjugate_updater.h @@ -1,31 +1,57 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_CONJUGATE_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_CONJUGATE_UPDATER_H_ -// NOT WORKING AT THE MOMENT - #include "abstract_updater.h" +#include "src/hierarchies/likelihoods/abstract_likelihood.h" +#include "src/hierarchies/priors/abstract_prior_model.h" +template class ConjugateUpdater : public AbstractUpdater { public: ConjugateUpdater() = default; ~ConjugateUpdater() = default; + bool is_conjugate() const override { return true; }; - void draw(AbstractLikelihood &like, AbstractPriorModel &prior, + void draw(AbstractLikelihood& like, AbstractPriorModel& prior, bool update_params) override; - virtual void compute_posterior_hypers(UniNormLikelihood &like, - NIGPriorModel &prior) = 0; + virtual void compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) = 0; + + protected: + Likelihood& downcast_likelihood(AbstractLikelihood& like_); + PriorModel& downcast_prior(AbstractPriorModel& prior_); }; -void ConjugateUpdater::draw(AbstractLikelihood &like, - AbstractPriorModel &prior, bool update_params) { +// Methods' definitions +template +Likelihood& ConjugateUpdater::downcast_likelihood( + AbstractLikelihood& like_) { + return static_cast(like_); +} + +template +PriorModel& ConjugateUpdater::downcast_prior( + AbstractPriorModel& prior_) { + return static_cast(prior_); +} + +template +void ConjugateUpdater::draw(AbstractLikelihood& like, + AbstractPriorModel& prior, + bool update_params) { + // Likelihood and PriorModel downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Sample from the full conditional of a conjugate hierarchy bool set_card = true; - if (like.get_card() == 0) { - like.set_state_from_proto(*prior.sample(false), !set_card); + if (likecast.get_card() == 0) { + likecast.set_state_from_proto(*priorcast.sample(false), !set_card); } else { if (update_params) { - compute_posterior_hypers(like, prior); + compute_posterior_hypers(likecast, priorcast); } - like.set_state_from_proto(*prior.sample(true), !set_card); + likecast.set_state_from_proto(*prior.sample(true), !set_card); } } From 9ac86980fb3d6e13488bec35ca08a625340bf7d5 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 25 Jan 2022 19:39:00 +0100 Subject: [PATCH 079/317] metropolis and random walk updaters --- src/hierarchies/updaters/CMakeLists.txt | 4 +- src/hierarchies/updaters/abstract_updater.h | 4 ++ src/hierarchies/updaters/metropolis_updater.h | 37 +++++++++++++ .../updaters/random_walk_updater.h | 52 +++++++++++++++++++ 4 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 src/hierarchies/updaters/metropolis_updater.h create mode 100644 src/hierarchies/updaters/random_walk_updater.h diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index 616ef93b4..f42810abf 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -1,7 +1,9 @@ target_sources(bayesmix PUBLIC - # abstract_updater.h + abstract_updater.h # conjugate_updater.h nnig_updater.h nnig_updater.cc + metropolis_updater.h + random_walk_updater.h ) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 7815dbde1..561c2258c 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -18,6 +18,10 @@ class AbstractUpdater { AbstractPriorModel &prior) = 0; virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, bool update_params) = 0; + virtual void compute_posterior_hypers(UniNormLikelihood &like, + NIGPriorModel &prior) { + throw std::runtime_error("compute_posterior_hypers not implemented"); + } }; #endif // BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ diff --git a/src/hierarchies/updaters/metropolis_updater.h b/src/hierarchies/updaters/metropolis_updater.h new file mode 100644 index 000000000..1a8139f54 --- /dev/null +++ b/src/hierarchies/updaters/metropolis_updater.h @@ -0,0 +1,37 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_METROPOLIS_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_METROPOLIS_UPDATER_H_ + +#include "abstract_updater.h" + +class MetropolisUpdater: public AbstractUpdater { + public: + + void draw(AbstractLikelihood &like, AbstractPriorModel &prior, + bool update_params) override { + Eigen::VectorXd curr_state = like.get_unconstrained_state(); + Eigen::VectorXd prop_state = sample_proposal(curr_state, like, prior); + + + double log_arate = like.cluster_lpdf_from_unconstrained(prop_state) - + like.cluster_lpdf_from_unconstrained(curr_state) + + proposal_lpdf(curr_state, prop_state, like, prior) - + proposal_lpdf(prop_state, curr_state, like, prior); + + auto& rng = bayesmix::Rng::Instance().get(); + if (std::log(stan::math::uniform_rng(0, 1, rng)) < log_arate) { + like.set_state_from_unconstrained(prop_state); + } + } + + virtual Eigen::VectorXd sample_proposal( + Eigen::VectorXd curr_state, AbstractLikelihood &like, + AbstractPriorModel &prior) = 0; + + virtual double proposal_lpdf( + Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, + AbstractLikelihood &like, AbstractPriorModel &prior) = 0; + + +}; + +#endif diff --git a/src/hierarchies/updaters/random_walk_updater.h b/src/hierarchies/updaters/random_walk_updater.h new file mode 100644 index 000000000..a35ed1e10 --- /dev/null +++ b/src/hierarchies/updaters/random_walk_updater.h @@ -0,0 +1,52 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_RANDOM_WALK_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_RANDOM_WALK_UPDATER_H_ + +#include "metropolis_updater.h" + +class RandomWalkUpdater: public MetropolisUpdater { + protected: + double step_size; + + public: + RandomWalkUpdater() = default; + ~RandomWalkUpdater() = default; + + RandomWalkUpdater(double step_size): step_size(step_size) {} + + Eigen::VectorXd sample_proposal( + Eigen::VectorXd curr_state, AbstractLikelihood &like, + AbstractPriorModel &prior) override { + Eigen::VectorXd step(curr_state.size()); + auto& rng = bayesmix::Rng::Instance().get(); + for (int i=0; i < curr_state.size(); i++) { + step(i) = stan::math::normal_rng(0, step_size, rng); + } + return curr_state + step; + } + + double proposal_lpdf( + Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, + AbstractLikelihood &like, AbstractPriorModel &prior) override { + double out; + for (int i=0; i < prop_state.size(); i++) { + out += stan::math::normal_lpdf(prop_state(i), curr_state(i), step_size); + } + return out; + } + + void initialize(AbstractLikelihood& like, AbstractPriorModel& prior) override { + // Weird to have it here!! + // prior.initialize(); + Eigen::VectorXd curr_state = Eigen::VectorXd::Ones(2); + like.set_state_from_unconstrained(curr_state); + } + + std::shared_ptr clone() const { + auto out = + std::make_shared(static_cast(*this)); + return out; + } + +}; + +#endif \ No newline at end of file From d563a7950e844db4f7bc2f46948af3f749783885 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 25 Jan 2022 19:39:36 +0100 Subject: [PATCH 080/317] working with unconstrained parameters --- src/hierarchies/likelihoods/CMakeLists.txt | 10 ++---- .../likelihoods/abstract_likelihood.h | 17 +++++++++ src/hierarchies/likelihoods/base_likelihood.h | 9 +++++ src/hierarchies/likelihoods/states.h | 35 ++++++++++++++++++- .../likelihoods/uni_norm_likelihood.cc | 11 ++++++ .../likelihoods/uni_norm_likelihood.h | 4 +++ 6 files changed, 77 insertions(+), 9 deletions(-) diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index 3e5495073..7444c0116 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -1,8 +1,2 @@ -target_sources(bayesmix - PUBLIC - abstract_likelihood.h - base_likelihood.h - states.h - uni_norm_likelihood.h - uni_norm_likelihood.cc -) +target_sources(bayesmix PUBLIC abstract_likelihood.h base_likelihood.h states + .h uni_norm_likelihood.h uni_norm_likelihood.cc) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 9beef7729..ec3d5ce75 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -31,6 +31,18 @@ class AbstractLikelihood { } } + //! Evaluates the log likelihood over all the data in the cluster + //! given unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. + virtual double cluster_lpdf_from_unconstrained( + Eigen::VectorXd unconstrained_params) { + throw std::runtime_error( + "cluster_lpdf_from_unconstrained() not yet implemented"); + } + virtual Eigen::VectorXd lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const = 0; @@ -45,6 +57,9 @@ class AbstractLikelihood { virtual void set_state_from_proto(const google::protobuf::Message &state_, bool update_card = true) = 0; + virtual void set_state_from_unconstrained( + const Eigen::VectorXd &unconstrained_state) = 0; + // IMPLEMENTED in BaseLikelihood virtual void write_state_to_proto(google::protobuf::Message *out) const = 0; @@ -70,6 +85,8 @@ class AbstractLikelihood { virtual void clear_summary_statistics() = 0; + virtual Eigen::VectorXd get_unconstrained_state() = 0; + protected: virtual std::shared_ptr get_state_proto() const = 0; diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index a68f7c823..3297114dd 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -39,8 +39,17 @@ class BaseLikelihood : public AbstractLikelihood { State get_state() const { return state; } + Eigen::VectorXd get_unconstrained_state() override { + return state.get_unconstrained(); + } + void set_state(const State &_state) { state = _state; }; + void set_state_from_unconstrained( + const Eigen::VectorXd &unconstrained_state) override { + state.set_from_unconstrained(unconstrained_state); + } + void add_datum( const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; diff --git a/src/hierarchies/likelihoods/states.h b/src/hierarchies/likelihoods/states.h index 70dc7f032..650fe0003 100644 --- a/src/hierarchies/likelihoods/states.h +++ b/src/hierarchies/likelihoods/states.h @@ -2,11 +2,44 @@ #define BAYESMIX_HIERARCHIES_LIKELIHOOD_STATES_H_ #include +#include + +#include "algorithm_state.pb.h" namespace State { -struct UniLS { +class UniLS { + public: double mean, var; + + Eigen::VectorXd get_unconstrained() { + Eigen::VectorXd out(2); + out << mean, std::log(var); + return out; + } + + void set_from_unconstrained(Eigen::VectorXd in) { + mean = in(0); + var = std::exp(in(1)); + } + + void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { + mean = state_.uni_ls_state().mean(); + var = state_.uni_ls_state().var(); + } + + bayesmix::AlgorithmState::ClusterState get_as_proto() { + bayesmix::AlgorithmState::ClusterState state; + state.mutable_uni_ls_state()->set_mean(mean); + state.mutable_uni_ls_state()->set_var(var); + return state; + } + + double log_det_jac() { + double out = 0; + stan::math::positive_constrain(std::log(var), out); + return out; + } }; struct MultiLS { diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc index b0ceada1f..ce5419116 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -35,3 +35,14 @@ void UniNormLikelihood::clear_summary_statistics() { data_sum = 0; data_sum_squares = 0; } + +double UniNormLikelihood::cluster_lpdf_from_unconstrained( + Eigen::VectorXd unconstrained_params) { + assert(unconstrained_params.size() == 2); + double mean = unconstrained_params(0); + double var = std::exp(unconstrained_params(1)); + double out = -(data_sum_squares - 2 * mean * data_sum + card * mean * mean) / + (2 * var); + out -= card * 0.5 * std::log(stan::math::TWO_PI * var); + return out; +} diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 133e1f81e..5ce0ed9c2 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -25,6 +25,10 @@ class UniNormLikelihood double get_data_sum() const { return data_sum; }; double get_data_sum_squares() const { return data_sum_squares; }; + // The unconstrained parameters are mean and log(var) + double cluster_lpdf_from_unconstrained( + Eigen::VectorXd unconstrained_params) override; + protected: std::shared_ptr get_state_proto() const override; From eb7c7f35cc4076d49370c4b8160e0711cadd5e39 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 25 Jan 2022 19:40:03 +0100 Subject: [PATCH 081/317] unconstrained parameters in priors --- src/hierarchies/priors/CMakeLists.txt | 10 ++-------- src/hierarchies/priors/abstract_prior_model.h | 10 ++++++++++ src/hierarchies/priors/nig_prior_model.cc | 7 +++++++ src/hierarchies/priors/nig_prior_model.h | 3 +++ 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt index 8547a6acf..726c08f49 100644 --- a/src/hierarchies/priors/CMakeLists.txt +++ b/src/hierarchies/priors/CMakeLists.txt @@ -1,8 +1,2 @@ -target_sources(bayesmix - PUBLIC - abstract_prior_model.h - base_prior_model.h - hyperparams.h - nig_prior_model.h - nig_prior_model.cc -) +target_sources(bayesmix PUBLIC abstract_prior_model.h base_prior_model + .h hyperparams.h nig_prior_model.h nig_prior_model.cc) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 5ccc7ddfc..2852a748e 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -20,6 +20,16 @@ class AbstractPriorModel { virtual double lpdf(const google::protobuf::Message &state_) = 0; + //! Evaluates the log likelihood for unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. + virtual double lpdf_from_unconstrained( + Eigen::VectorXd unconstrained_params) { + throw std::runtime_error("lpdf_from_unconstrained() not yet implemented"); + } + // Da pensare, come restituisco lo stato? magari un pointer? Oppure delego virtual std::shared_ptr sample( bool use_post_hypers) = 0; diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index 7df752685..d2eab6991 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -84,6 +84,13 @@ double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { return target; } +double NIGPriorModel::lpdf_from_unconstrained( + Eigen::VectorXd unconstrained_params) { + State::UniLS state; + state.set_from_unconstrained(unconstrained_params); + return lpdf(state.get_as_proto()) + state.log_det_jac(); +} + std::shared_ptr NIGPriorModel::sample( bool use_post_hypers) { auto &rng = bayesmix::Rng::Instance().get(); diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index 1ffda2007..dcdd327df 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -22,6 +22,9 @@ class NIGPriorModel : public BasePriorModel sample( bool use_post_hypers) override; From 8a5ad95df0a28bdffe65bd976fd906029a2e734c Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 25 Jan 2022 19:40:31 +0100 Subject: [PATCH 082/317] clang --- src/hierarchies/updaters/CMakeLists.txt | 14 ++- src/hierarchies/updaters/abstract_updater.h | 2 +- src/hierarchies/updaters/metropolis_updater.h | 55 ++++++------ .../updaters/random_walk_updater.h | 90 +++++++++---------- 4 files changed, 77 insertions(+), 84 deletions(-) diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index f42810abf..e3536da34 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -1,9 +1,5 @@ -target_sources(bayesmix - PUBLIC - abstract_updater.h - # conjugate_updater.h - nnig_updater.h - nnig_updater.cc - metropolis_updater.h - random_walk_updater.h -) +target_sources(bayesmix PUBLIC abstract_updater + .h +#conjugate_updater.h + nnig_updater.h nnig_updater.cc metropolis_updater + .h random_walk_updater.h) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 561c2258c..f8e3f7d53 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -19,7 +19,7 @@ class AbstractUpdater { virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, bool update_params) = 0; virtual void compute_posterior_hypers(UniNormLikelihood &like, - NIGPriorModel &prior) { + NIGPriorModel &prior) { throw std::runtime_error("compute_posterior_hypers not implemented"); } }; diff --git a/src/hierarchies/updaters/metropolis_updater.h b/src/hierarchies/updaters/metropolis_updater.h index 1a8139f54..b328cf7d4 100644 --- a/src/hierarchies/updaters/metropolis_updater.h +++ b/src/hierarchies/updaters/metropolis_updater.h @@ -3,35 +3,32 @@ #include "abstract_updater.h" -class MetropolisUpdater: public AbstractUpdater { - public: - - void draw(AbstractLikelihood &like, AbstractPriorModel &prior, - bool update_params) override { - Eigen::VectorXd curr_state = like.get_unconstrained_state(); - Eigen::VectorXd prop_state = sample_proposal(curr_state, like, prior); - - - double log_arate = like.cluster_lpdf_from_unconstrained(prop_state) - - like.cluster_lpdf_from_unconstrained(curr_state) + - proposal_lpdf(curr_state, prop_state, like, prior) - - proposal_lpdf(prop_state, curr_state, like, prior); - - auto& rng = bayesmix::Rng::Instance().get(); - if (std::log(stan::math::uniform_rng(0, 1, rng)) < log_arate) { - like.set_state_from_unconstrained(prop_state); - } - } - - virtual Eigen::VectorXd sample_proposal( - Eigen::VectorXd curr_state, AbstractLikelihood &like, - AbstractPriorModel &prior) = 0; - - virtual double proposal_lpdf( - Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, - AbstractLikelihood &like, AbstractPriorModel &prior) = 0; - - +class MetropolisUpdater : public AbstractUpdater { + public: + void draw(AbstractLikelihood &like, AbstractPriorModel &prior, + bool update_params) override { + Eigen::VectorXd curr_state = like.get_unconstrained_state(); + Eigen::VectorXd prop_state = sample_proposal(curr_state, like, prior); + + double log_arate = like.cluster_lpdf_from_unconstrained(prop_state) - + like.cluster_lpdf_from_unconstrained(curr_state) + + proposal_lpdf(curr_state, prop_state, like, prior) - + proposal_lpdf(prop_state, curr_state, like, prior); + + auto &rng = bayesmix::Rng::Instance().get(); + if (std::log(stan::math::uniform_rng(0, 1, rng)) < log_arate) { + like.set_state_from_unconstrained(prop_state); + } + } + + virtual Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, + AbstractLikelihood &like, + AbstractPriorModel &prior) = 0; + + virtual double proposal_lpdf(Eigen::VectorXd prop_state, + Eigen::VectorXd curr_state, + AbstractLikelihood &like, + AbstractPriorModel &prior) = 0; }; #endif diff --git a/src/hierarchies/updaters/random_walk_updater.h b/src/hierarchies/updaters/random_walk_updater.h index a35ed1e10..388722215 100644 --- a/src/hierarchies/updaters/random_walk_updater.h +++ b/src/hierarchies/updaters/random_walk_updater.h @@ -3,50 +3,50 @@ #include "metropolis_updater.h" -class RandomWalkUpdater: public MetropolisUpdater { - protected: - double step_size; - - public: - RandomWalkUpdater() = default; - ~RandomWalkUpdater() = default; - - RandomWalkUpdater(double step_size): step_size(step_size) {} - - Eigen::VectorXd sample_proposal( - Eigen::VectorXd curr_state, AbstractLikelihood &like, - AbstractPriorModel &prior) override { - Eigen::VectorXd step(curr_state.size()); - auto& rng = bayesmix::Rng::Instance().get(); - for (int i=0; i < curr_state.size(); i++) { - step(i) = stan::math::normal_rng(0, step_size, rng); - } - return curr_state + step; - } - - double proposal_lpdf( - Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, - AbstractLikelihood &like, AbstractPriorModel &prior) override { - double out; - for (int i=0; i < prop_state.size(); i++) { - out += stan::math::normal_lpdf(prop_state(i), curr_state(i), step_size); - } - return out; - } - - void initialize(AbstractLikelihood& like, AbstractPriorModel& prior) override { - // Weird to have it here!! - // prior.initialize(); - Eigen::VectorXd curr_state = Eigen::VectorXd::Ones(2); - like.set_state_from_unconstrained(curr_state); - } - - std::shared_ptr clone() const { - auto out = - std::make_shared(static_cast(*this)); - return out; - } - +class RandomWalkUpdater : public MetropolisUpdater { + protected: + double step_size; + + public: + RandomWalkUpdater() = default; + ~RandomWalkUpdater() = default; + + RandomWalkUpdater(double step_size) : step_size(step_size) {} + + Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, + AbstractLikelihood &like, + AbstractPriorModel &prior) override { + Eigen::VectorXd step(curr_state.size()); + auto &rng = bayesmix::Rng::Instance().get(); + for (int i = 0; i < curr_state.size(); i++) { + step(i) = stan::math::normal_rng(0, step_size, rng); + } + return curr_state + step; + } + + double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, + AbstractLikelihood &like, + AbstractPriorModel &prior) override { + double out; + for (int i = 0; i < prop_state.size(); i++) { + out += stan::math::normal_lpdf(prop_state(i), curr_state(i), step_size); + } + return out; + } + + void initialize(AbstractLikelihood &like, + AbstractPriorModel &prior) override { + // Weird to have it here!! + // prior.initialize(); + Eigen::VectorXd curr_state = Eigen::VectorXd::Ones(2); + like.set_state_from_unconstrained(curr_state); + } + + std::shared_ptr clone() const { + auto out = std::make_shared( + static_cast(*this)); + return out; + } }; -#endif \ No newline at end of file +#endif From 7b19cb5ed4e969753f29cb8cbeaaf03140de801e Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 25 Jan 2022 19:40:59 +0100 Subject: [PATCH 083/317] tests for unconstrained stuff --- test/CMakeLists.txt | 22 +++++++++++----------- test/eigenproto.cc | 15 +++++++++++++++ test/likelihoods.cc | 34 ++++++++++++++++++++++++++++++++++ test/mfm_mixing.cc | 9 +++++++++ test/prior_models.cc | 27 +++++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 11 deletions(-) create mode 100644 test/eigenproto.cc create mode 100644 test/mfm_mixing.cc diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ea406f1b3..22118d049 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -16,20 +16,20 @@ set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) add_executable(test_bayesmix $ - write_proto.cc - proto_utils.cc + # write_proto.cc + # proto_utils.cc likelihoods.cc prior_models.cc - hierarchies.cc - lpdf.cc + # hierarchies.cc + # lpdf.cc # priors.cc // OLD, USEREI prior_models.cc - eigen_utils.cc - distributions.cc - semi_hdp.cc - collectors.cc - runtime.cc - rng.cc - logit_sb.cc + # eigen_utils.cc + # distributions.cc + # semi_hdp.cc + # collectors.cc + # runtime.cc + # rng.cc + # logit_sb.cc ) target_include_directories(test_bayesmix PUBLIC ${INCLUDE_PATHS}) diff --git a/test/eigenproto.cc b/test/eigenproto.cc new file mode 100644 index 000000000..e28183dc9 --- /dev/null +++ b/test/eigenproto.cc @@ -0,0 +1,15 @@ +#include "src/utils/eigenproto.h" + +#include + + +TEST(eigenproto, vector_conversion) { + Eigen::VectorXd v = Eigen::VectorXd::Random(3); + MyVectorType mv(v); + bayesmix::Vector vproto; + bayesmix::to_proto(mv, &vproto); + + bayesmix::Vector mvproto = mv; + + EXPECT_EQ(vproto.DebugString(), mvproto.DebugString()); +} \ No newline at end of file diff --git a/test/likelihoods.cc b/test/likelihoods.cc index d5af2a024..eb7c3f125 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -74,3 +74,37 @@ TEST(uni_norm_likelihood, eval_lpdf) { // Check if they coincides ASSERT_EQ(evals, evals_copy); } + +TEST(uni_norm_likelihood, eval_lpdf_unconstrained) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + double mean = 5; + double var = 1; + state_.set_mean(mean); + state_.set_var(var); + Eigen::VectorXd unconstrained_params(2); + unconstrained_params << mean, std::log(var); + clust_state_.mutable_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Add new datum to likelihood + Eigen::VectorXd data(3); + data << 4.5, 5.1, 2.5; + double lpdf = 0.0; + for (int i = 0; i < data.size(); ++i) { + like->add_datum(i, data.row(i)); + lpdf += like->lpdf(data.row(i)); + } + + double clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); + ASSERT_DOUBLE_EQ(lpdf, clus_lpdf); + + unconstrained_params(0) = 4.0; + clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); + ASSERT_TRUE(std::abs(clus_lpdf - lpdf) > 1e-5); +} + diff --git a/test/mfm_mixing.cc b/test/mfm_mixing.cc new file mode 100644 index 000000000..16fb815c1 --- /dev/null +++ b/test/mfm_mixing.cc @@ -0,0 +1,9 @@ +#include + +#include +#include + +TEST(mfm_mixing, mfm_mixing_test) { + ASSERT_EQ(1, 1); + ASSERT_GT(0, 1); +} \ No newline at end of file diff --git a/test/prior_models.cc b/test/prior_models.cc index 3c55ebd70..2db111664 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -125,3 +125,30 @@ TEST(nig_prior_model, sample) { // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); } + +TEST(nig_prior_model, unconstrained_lpdf) { + // Instance + auto prior = std::make_shared(); + + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + hypers_proto.mutable_nnig_state()->set_mean(5.0); + hypers_proto.mutable_nnig_state()->set_var_scaling(0.1); + hypers_proto.mutable_nnig_state()->set_shape(4.0); + hypers_proto.mutable_nnig_state()->set_scale(3.0); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + double mean = 0.0; + double var = 5.0; + + bayesmix::AlgorithmState::ClusterState state; + state.mutable_uni_ls_state()->set_mean(mean); + state.mutable_uni_ls_state()->set_var(var); + Eigen::VectorXd unconstrained_params(2); + unconstrained_params << mean, std::log(var); + + ASSERT_DOUBLE_EQ(prior->lpdf(state) + std::log(std::exp(unconstrained_params(1))), + prior->lpdf_from_unconstrained(unconstrained_params)); +} + From 0b2e12ca0d1cfd81764d2e94cf9c5ef82a4dd2cf Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 25 Jan 2022 19:41:25 +0100 Subject: [PATCH 084/317] removed stuff --- test/eigenproto.cc | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 test/eigenproto.cc diff --git a/test/eigenproto.cc b/test/eigenproto.cc deleted file mode 100644 index e28183dc9..000000000 --- a/test/eigenproto.cc +++ /dev/null @@ -1,15 +0,0 @@ -#include "src/utils/eigenproto.h" - -#include - - -TEST(eigenproto, vector_conversion) { - Eigen::VectorXd v = Eigen::VectorXd::Random(3); - MyVectorType mv(v); - bayesmix::Vector vproto; - bayesmix::to_proto(mv, &vproto); - - bayesmix::Vector mvproto = mv; - - EXPECT_EQ(vproto.DebugString(), mvproto.DebugString()); -} \ No newline at end of file From 9131cb756356384cb0485c8727f49ec5480dee59 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 25 Jan 2022 19:43:36 +0100 Subject: [PATCH 085/317] more tests --- test/likelihoods.cc | 4 ++-- test/prior_models.cc | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/likelihoods.cc b/test/likelihoods.cc index eb7c3f125..c0bdfb276 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -100,11 +100,11 @@ TEST(uni_norm_likelihood, eval_lpdf_unconstrained) { lpdf += like->lpdf(data.row(i)); } - double clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); + double clus_lpdf = + like->cluster_lpdf_from_unconstrained(unconstrained_params); ASSERT_DOUBLE_EQ(lpdf, clus_lpdf); unconstrained_params(0) = 4.0; clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); ASSERT_TRUE(std::abs(clus_lpdf - lpdf) > 1e-5); } - diff --git a/test/prior_models.cc b/test/prior_models.cc index 2db111664..0ef122c98 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -148,7 +148,7 @@ TEST(nig_prior_model, unconstrained_lpdf) { Eigen::VectorXd unconstrained_params(2); unconstrained_params << mean, std::log(var); - ASSERT_DOUBLE_EQ(prior->lpdf(state) + std::log(std::exp(unconstrained_params(1))), - prior->lpdf_from_unconstrained(unconstrained_params)); + ASSERT_DOUBLE_EQ( + prior->lpdf(state) + std::log(std::exp(unconstrained_params(1))), + prior->lpdf_from_unconstrained(unconstrained_params)); } - From 61abfb123eb435850854804cba81edaf93d53b13 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 25 Jan 2022 19:43:56 +0100 Subject: [PATCH 086/317] temporary executable --- CMakeLists.txt | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1dbcd2174..4ba0d0eee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ project(bayesmix) set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) +set(CMAKE_BUILD_TYPE Debug) set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH}) set(CMAKE_CXX_FLAGS_RELEASE "-O3 -funroll-loops -ftree-vectorize") @@ -169,4 +169,9 @@ if (NOT DISABLE_DOCS) add_subdirectory(docs) endif() -add_subdirectory(examples) +add_executable(test_mh $ test_mh_updater.cpp) +target_include_directories(test_mh PUBLIC ${INCLUDE_PATHS}) +target_link_libraries(test_mh PUBLIC ${LINK_LIBRARIES}) +target_compile_options(test_mh PUBLIC ${COMPILE_OPTIONS}) + +# add_subdirectory(examples) From a190af4d7778718c8c6aff1ad5ec75c6cca02d51 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 25 Jan 2022 19:44:48 +0100 Subject: [PATCH 087/317] nnig with rw updater --- src/hierarchies/nnig_hierarchy.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index c2e7aa0ab..dbd841020 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -16,9 +16,10 @@ #include "likelihoods/uni_norm_likelihood.h" #include "priors/nig_prior_model.h" #include "updaters/nnig_updater.h" +#include "updaters/random_walk_updater.h" class NNIGHierarchy : public BaseHierarchy { + NIGPriorModel, RandomWalkUpdater> { public: NNIGHierarchy() = default; ~NNIGHierarchy() = default; From a69c570299803828efdd91e6aa72eacf2e517735 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 12:32:07 +0100 Subject: [PATCH 088/317] Add hierarchy structure for updaters --- src/hierarchies/updaters/CMakeLists.txt | 4 +- src/hierarchies/updaters/abstract_updater.h | 10 +--- src/hierarchies/updaters/conjugate_updater.h | 48 +++++++++++++++----- 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index 616ef93b4..0b4135292 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -1,7 +1,7 @@ target_sources(bayesmix PUBLIC - # abstract_updater.h - # conjugate_updater.h + abstract_updater.h + conjugate_updater.h nnig_updater.h nnig_updater.cc ) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 7815dbde1..295e03ab0 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -1,21 +1,13 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ -// NOT WORKING AT THE MOMENT - -#include - #include "src/hierarchies/likelihoods/abstract_likelihood.h" #include "src/hierarchies/priors/abstract_prior_model.h" class AbstractUpdater { public: virtual ~AbstractUpdater() = default; - // virtual std::shared_ptr clone() const = 0; NON CREDO CI - // SERVA - bool is_conjugate() const { return false; }; - virtual void initialize(AbstractLikelihood &like, - AbstractPriorModel &prior) = 0; + virtual bool is_conjugate() const { return false; }; virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, bool update_params) = 0; }; diff --git a/src/hierarchies/updaters/conjugate_updater.h b/src/hierarchies/updaters/conjugate_updater.h index e39f94efd..99552686e 100644 --- a/src/hierarchies/updaters/conjugate_updater.h +++ b/src/hierarchies/updaters/conjugate_updater.h @@ -1,31 +1,57 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_CONJUGATE_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_CONJUGATE_UPDATER_H_ -// NOT WORKING AT THE MOMENT - #include "abstract_updater.h" +#include "src/hierarchies/likelihoods/abstract_likelihood.h" +#include "src/hierarchies/priors/abstract_prior_model.h" +template class ConjugateUpdater : public AbstractUpdater { public: ConjugateUpdater() = default; ~ConjugateUpdater() = default; + bool is_conjugate() const override { return true; }; - void draw(AbstractLikelihood &like, AbstractPriorModel &prior, + void draw(AbstractLikelihood& like, AbstractPriorModel& prior, bool update_params) override; - virtual void compute_posterior_hypers(UniNormLikelihood &like, - NIGPriorModel &prior) = 0; + virtual void compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) = 0; + + protected: + Likelihood& downcast_likelihood(AbstractLikelihood& like_); + PriorModel& downcast_prior(AbstractPriorModel& prior_); }; -void ConjugateUpdater::draw(AbstractLikelihood &like, - AbstractPriorModel &prior, bool update_params) { +// Methods' definitions +template +Likelihood& ConjugateUpdater::downcast_likelihood( + AbstractLikelihood& like_) { + return static_cast(like_); +} + +template +PriorModel& ConjugateUpdater::downcast_prior( + AbstractPriorModel& prior_) { + return static_cast(prior_); +} + +template +void ConjugateUpdater::draw(AbstractLikelihood& like, + AbstractPriorModel& prior, + bool update_params) { + // Likelihood and PriorModel downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Sample from the full conditional of a conjugate hierarchy bool set_card = true; - if (like.get_card() == 0) { - like.set_state_from_proto(*prior.sample(false), !set_card); + if (likecast.get_card() == 0) { + likecast.set_state_from_proto(*priorcast.sample(false), !set_card); } else { if (update_params) { - compute_posterior_hypers(like, prior); + compute_posterior_hypers(likecast, priorcast); } - like.set_state_from_proto(*prior.sample(true), !set_card); + likecast.set_state_from_proto(*prior.sample(true), !set_card); } } From 7d1b9a4a23a36a690542f064e203b113cc61c9c3 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 13:17:26 +0100 Subject: [PATCH 089/317] initialize() now in charge to hierarchy classes --- src/hierarchies/base_hierarchy.h | 13 +++++++++++-- src/hierarchies/nnig_hierarchy.h | 10 ++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 7c921dc94..734aa227a 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -52,7 +52,6 @@ class BaseHierarchy : public AbstractHierarchy { // Cloning each component class out->set_likelihood(std::static_pointer_cast(like->clone())); out->set_prior(std::static_pointer_cast(prior->clone())); - out->set_updater(std::static_pointer_cast(updater->clone())); return out; }; @@ -231,7 +230,14 @@ class BaseHierarchy : public AbstractHierarchy { if (update_params) updater->compute_posterior_hypers(*like, *prior); }; - void initialize() override { updater->initialize(*like, *prior); }; + void initialize() override { + prior->initialize(); + if (is_conjugate()) + prior->set_posterior_hypers(prior->get_hypers()); + initialize_state(); + like->clear_data(); + like->clear_summary_statistics(); + }; bool is_multivariate() const override { return like->is_multivariate(); }; @@ -240,6 +246,9 @@ class BaseHierarchy : public AbstractHierarchy { bool is_conjugate() const override { return updater->is_conjugate(); }; protected: + + virtual void initialize_state() = 0; + virtual double marg_lpdf(const HyperParams ¶ms, const Eigen::RowVectorXd &datum) const { if (!is_conjugate()) { diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index c2e7aa0ab..7903d5353 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -27,6 +27,16 @@ class NNIGHierarchy : public BaseHierarchyget_hypers(); + // Initialize likelihood state + State::UniLS state; + state.mean = hypers.mean; + state.var = hypers.scale / (hypers.shape + 1); + like->set_state(state); + }; + double marg_lpdf(const HyperParams ¶ms, const Eigen::RowVectorXd &datum) const override { double sig_n = sqrt(params.scale * (params.var_scaling + 1) / From 9ddfaba175e236b3625a647095b928614a69058d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 13:18:35 +0100 Subject: [PATCH 090/317] NNIGUpdate is now a derived class --- src/hierarchies/updaters/nnig_updater.cc | 57 ++++++------------------ src/hierarchies/updaters/nnig_updater.h | 18 +++----- 2 files changed, 20 insertions(+), 55 deletions(-) diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc index be0f82d9c..d05f19c83 100644 --- a/src/hierarchies/updaters/nnig_updater.cc +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -1,42 +1,23 @@ #include "nnig_updater.h" -std::shared_ptr NNIGUpdater::clone() const { - auto out = - std::make_shared(static_cast(*this)); - return out; -}; - -void NNIGUpdater::initialize(UniNormLikelihood &like, NIGPriorModel &prior) { - // PriorModel Initialization - prior.initialize(); - Hyperparams::NIG hypers = prior.get_hypers(); - prior.set_posterior_hypers(hypers); +#include "src/hierarchies/likelihoods/states.h" +#include "src/hierarchies/priors/hyperparams.h" - // State initialization - State::UniLS state; - state.mean = hypers.mean; - state.var = hypers.scale / (hypers.shape + 1); - - // Likelihood Initalization - like.set_state(state); - like.clear_data(); - like.clear_summary_statistics(); -}; +void NNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); -void NNIGUpdater::compute_posterior_hypers(UniNormLikelihood &like, - NIGPriorModel &prior) { - // std::cout << "NNIGUpdater::compute_posterior_hypers()" << std::endl; // Getting required quantities from likelihood and prior - int card = like.get_card(); - double data_sum = like.get_data_sum(); - double data_sum_squares = like.get_data_sum_squares(); - auto hypers = prior.get_hypers(); - - // std::cout << "current cardinality: " << card << std::endl; + int card = likecast.get_card(); + double data_sum = likecast.get_data_sum(); + double data_sum_squares = likecast.get_data_sum_squares(); + auto hypers = priorcast.get_hypers(); // No update possible if (card == 0) { - prior.set_posterior_hypers(hypers); + priorcast.set_posterior_hypers(hypers); return; } @@ -52,18 +33,6 @@ void NNIGUpdater::compute_posterior_hypers(UniNormLikelihood &like, 0.5 * hypers.var_scaling * card * (y_bar - hypers.mean) * (y_bar - hypers.mean) / (card + hypers.var_scaling); - prior.set_posterior_hypers(post_params); + priorcast.set_posterior_hypers(post_params); return; }; - -void NNIGUpdater::draw(UniNormLikelihood &like, NIGPriorModel &prior, - bool update_params) { - if (like.get_card() == 0) { - like.set_state_from_proto(*prior.sample(false), false); - } else { - if (update_params) { - compute_posterior_hypers(like, prior); - } - like.set_state_from_proto(*prior.sample(true), false); - } -}; diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h index 09fc5a1e9..02d6bdce8 100644 --- a/src/hierarchies/updaters/nnig_updater.h +++ b/src/hierarchies/updaters/nnig_updater.h @@ -1,21 +1,17 @@ -#ifndef BAYESMIX_HIERARCHIES_NNIG_UPDATER_H_ -#define BAYESMIX_HIERARCHIES_NNIG_UPDATER_H_ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ -#include "src/hierarchies/likelihoods/states.h" +#include "conjugate_updater.h" #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" -#include "src/hierarchies/priors/hyperparams.h" #include "src/hierarchies/priors/nig_prior_model.h" -class NNIGUpdater { +class NNIGUpdater : public ConjugateUpdater { public: NNIGUpdater() = default; ~NNIGUpdater() = default; - std::shared_ptr clone() const; - bool is_conjugate() const { return true; }; - void draw(UniNormLikelihood& like, NIGPriorModel& prior, bool update_params); - void initialize(UniNormLikelihood& like, NIGPriorModel& prior); - void compute_posterior_hypers(UniNormLikelihood& like, NIGPriorModel& prior); + void compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) override; }; -#endif // BAYESMIX_HIERARCHIES_NNIG_UPDATER_H_ +#endif // BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ From 4c8f0bb78b71395d8bc8f3a3b6c0810f22c30468 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Wed, 26 Jan 2022 17:08:50 +0100 Subject: [PATCH 091/317] constructor --- resources/tutorial/algo.asciipb | 2 +- run_mcmc.cc | 7 +++++ src/algorithms/neal8_algorithm.h | 2 ++ src/hierarchies/abstract_hierarchy.h | 7 +++++ src/hierarchies/base_hierarchy.h | 30 +++++++++++++++++---- src/hierarchies/likelihoods/CMakeLists.txt | 9 +++++-- src/hierarchies/nnig_hierarchy.h | 6 +++-- src/hierarchies/priors/CMakeLists.txt | 9 +++++-- src/hierarchies/priors/nig_prior_model.cc | 3 +++ src/hierarchies/updaters/CMakeLists.txt | 13 +++++---- src/hierarchies/updaters/abstract_updater.h | 4 +-- src/hierarchies/updaters/nnig_updater.h | 3 ++- 12 files changed, 75 insertions(+), 20 deletions(-) diff --git a/resources/tutorial/algo.asciipb b/resources/tutorial/algo.asciipb index d748858d4..027ac8d9f 100644 --- a/resources/tutorial/algo.asciipb +++ b/resources/tutorial/algo.asciipb @@ -1,6 +1,6 @@ ##### GENERIC SETTINGS FOR ALL ALGORITHMS ##### # Algorithm ID string, e.g. "Neal2" -algo_id: "Neal3" +algo_id: "Neal8" # RNG initial seed: any nonnegative integer rng_seed: 20201124 diff --git a/run_mcmc.cc b/run_mcmc.cc index 8daada47b..93f00d637 100644 --- a/run_mcmc.cc +++ b/run_mcmc.cc @@ -177,6 +177,13 @@ int main(int argc, char *argv[]) { bayesmix::read_proto_from_file(args.get("--hier-args"), hier->get_mutable_prior()); + std::cout << "hier->prior: \n" + << hier->get_mutable_prior()->DebugString() << std::endl; + + auto updater = std::make_shared(0.25); + hier->set_updater(updater); + hier->initialize(); + // Read data matrices Eigen::MatrixXd data = bayesmix::read_eigen_matrix(args.get("--data-file")); diff --git a/src/algorithms/neal8_algorithm.h b/src/algorithms/neal8_algorithm.h index 55479db1a..11ad0048f 100644 --- a/src/algorithms/neal8_algorithm.h +++ b/src/algorithms/neal8_algorithm.h @@ -41,6 +41,8 @@ class Neal8Algorithm : public Neal2Algorithm { void read_params_from_proto( const bayesmix::AlgorithmParams ¶ms) override; + bool requires_conjugate_hierarchy() const override { return false; } + protected: void initialize() override; diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index fd642753f..8eee2d403 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -11,6 +11,9 @@ #include "algorithm_state.pb.h" #include "hierarchy_id.pb.h" +#include "src/hierarchies/likelihoods/abstract_likelihood.h" +#include "src/hierarchies/priors/abstract_prior_model.h" +#include "src/hierarchies/updaters/abstract_updater.h" #include "src/utils/rng.h" //! Abstract base class for a hierarchy object. @@ -48,6 +51,10 @@ class AbstractHierarchy { public: + virtual void set_likelihood(std::shared_ptr like_) = 0; + virtual void set_prior(std::shared_ptr prior_) = 0; + virtual void set_updater(std::shared_ptr updater_) = 0; + virtual ~AbstractHierarchy() = default; //! Returns an independent, data-less copy of this object diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 7c921dc94..a7dbf4931 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -39,12 +39,32 @@ class BaseHierarchy : public AbstractHierarchy { public: using HyperParams = decltype(prior->get_hypers()); - BaseHierarchy() = default; + + BaseHierarchy(std::shared_ptr like_ = nullptr, + std::shared_ptr prior_ = nullptr, + std::shared_ptr updater_ = nullptr) { + if (like_) { + set_likelihood(like_); + } + if (prior_) { + set_prior(prior_); + } + if (updater_) { + set_updater(updater_); + } + } + ~BaseHierarchy() = default; - void set_likelihood(std::shared_ptr like_) { like = like_; }; - void set_prior(std::shared_ptr prior_) { prior = prior_; }; - void set_updater(std::shared_ptr updater_) { updater = updater_; }; + void set_likelihood(std::shared_ptr like_) override { + like = std::static_pointer_cast(like_); + } + void set_prior(std::shared_ptr prior_) override { + prior = std::static_pointer_cast(prior_); + } + void set_updater(std::shared_ptr updater_) override { + updater = std::static_pointer_cast(updater_); + }; std::shared_ptr clone() const override { // Create copy of the hierarchy @@ -397,7 +417,7 @@ class BaseHierarchy : public AbstractHierarchy { // } // //! Down-casts the given generic proto message to a ClusterState proto -// const bayesmix::AlgorithmState::ClusterState &downcast_state( +// const bayesmix::AlgorithmState::ClusterState &f( // const google::protobuf::Message &state_) const { // return google::protobuf::internal::down_cast< // const bayesmix::AlgorithmState::ClusterState &>(state_); diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index 7444c0116..fbfcdf24a 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -1,2 +1,7 @@ -target_sources(bayesmix PUBLIC abstract_likelihood.h base_likelihood.h states - .h uni_norm_likelihood.h uni_norm_likelihood.cc) +target_sources(bayesmix PUBLIC + abstract_likelihood.h + base_likelihood.h + states.h + uni_norm_likelihood.h + uni_norm_likelihood.cc +) diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index dbd841020..293c7e334 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -15,15 +15,17 @@ #include "base_hierarchy.h" #include "likelihoods/uni_norm_likelihood.h" #include "priors/nig_prior_model.h" -#include "updaters/nnig_updater.h" +// #include "updaters/nnig_updater.h" #include "updaters/random_walk_updater.h" class NNIGHierarchy : public BaseHierarchy { public: - NNIGHierarchy() = default; ~NNIGHierarchy() = default; + using BaseHierarchy::BaseHierarchy; + bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNIG; } diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt index 726c08f49..196c9bf04 100644 --- a/src/hierarchies/priors/CMakeLists.txt +++ b/src/hierarchies/priors/CMakeLists.txt @@ -1,2 +1,7 @@ -target_sources(bayesmix PUBLIC abstract_prior_model.h base_prior_model - .h hyperparams.h nig_prior_model.h nig_prior_model.cc) +target_sources(bayesmix PUBLIC + abstract_prior_model.h + base_prior_model.h + hyperparams.h + nig_prior_model.h + nig_prior_model.cc +) diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index d2eab6991..10caa57eb 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -95,7 +95,10 @@ std::shared_ptr NIGPriorModel::sample( bool use_post_hypers) { auto &rng = bayesmix::Rng::Instance().get(); Hyperparams::NIG params = use_post_hypers ? post_hypers : hypers; + std::cout << "use_post_hypers: " << use_post_hypers << std::endl; + std::cout << "shape: " << params.shape << ", scale: " << params.scale + << std::endl; double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); double mean = stan::math::normal_rng(params.mean, sqrt(var / params.var_scaling), rng); diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index e3536da34..b366686f9 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -1,5 +1,8 @@ -target_sources(bayesmix PUBLIC abstract_updater - .h -#conjugate_updater.h - nnig_updater.h nnig_updater.cc metropolis_updater - .h random_walk_updater.h) +target_sources(bayesmix PUBLIC + abstract_updater.h + #conjugate_updater.h + # nnig_updater.h + # nnig_updater.cc + metropolis_updater.h + random_walk_updater.h +) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index f8e3f7d53..830eb5429 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -18,8 +18,8 @@ class AbstractUpdater { AbstractPriorModel &prior) = 0; virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, bool update_params) = 0; - virtual void compute_posterior_hypers(UniNormLikelihood &like, - NIGPriorModel &prior) { + virtual void compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) { throw std::runtime_error("compute_posterior_hypers not implemented"); } }; diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h index 09fc5a1e9..c097dd663 100644 --- a/src/hierarchies/updaters/nnig_updater.h +++ b/src/hierarchies/updaters/nnig_updater.h @@ -15,7 +15,8 @@ class NNIGUpdater { bool is_conjugate() const { return true; }; void draw(UniNormLikelihood& like, NIGPriorModel& prior, bool update_params); void initialize(UniNormLikelihood& like, NIGPriorModel& prior); - void compute_posterior_hypers(UniNormLikelihood& like, NIGPriorModel& prior); + void compute_posterior_hypers(UniNormLikelihood& like, + NIGPriorModel& prior) override; }; #endif // BAYESMIX_HIERARCHIES_NNIG_UPDATER_H_ From adba2694e7aaa248e5702b29bdecea42c8a4b858 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 22:44:56 +0100 Subject: [PATCH 092/317] Prepared notebook for NNxIG hierarchy --- python/notebooks/gaussian_mix_NNxIG.ipynb | 82 ++++++++++++++--------- 1 file changed, 49 insertions(+), 33 deletions(-) diff --git a/python/notebooks/gaussian_mix_NNxIG.ipynb b/python/notebooks/gaussian_mix_NNxIG.ipynb index 9e92c562e..96e3fa370 100644 --- a/python/notebooks/gaussian_mix_NNxIG.ipynb +++ b/python/notebooks/gaussian_mix_NNxIG.ipynb @@ -3,20 +3,25 @@ { "cell_type": "code", "execution_count": null, - "id": "49d3291e", + "id": "6c73fa6a", "metadata": {}, "outputs": [], "source": [ - "import os\n", - "os.environ[\"BAYESMIX_EXE\"] = \"/home/m_gianella/Documents/GitHub/bayesmix/build/run_mcmc\"\n", - "\n", - "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", + "import numpy as np\n", "\n", - "from bayesmixpy import run_mcmc\n", - "from tensorflow_probability.substrates import numpy as tfp\n", - "tfd = tfp.distributions" + "from bayesmixpy import run_mcmc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79825dfc", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"BAYESMIX_EXE\"] = \"../../build/run_mcmc\"" ] }, { @@ -26,17 +31,15 @@ "metadata": {}, "outputs": [], "source": [ - "np.random.seed(123)\n", - "\n", - "# Set true parameters\n", - "N = 500\n", - "Ncomp = 3\n", - "means = [-5.0, 0.0, 5.0]\n", - "sds = [0.5, 2.0, 0.25]\n", - "weights = np.ones(Ncomp)/Ncomp\n", + "# Generate data\n", + "data = np.concatenate([\n", + " np.random.normal(loc=3, scale=1, size=100),\n", + " np.random.normal(loc=-3, scale=1, size=100),\n", + "])\n", "\n", - "cluster_allocs = tfd.Categorical(probs=weights).sample(N)\n", - "data = np.stack([tfd.Normal(means[cluster_allocs[i]], sds[cluster_allocs[i]]).sample() for i in range(N)])" + "# Plot data\n", + "plt.hist(data)\n", + "plt.show()" ] }, { @@ -46,7 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Setup parameters for bayesmixpy\n", + "# Hierarchy settings\n", "hier_params = \\\n", "\"\"\"\n", "fixed_values {\n", @@ -57,24 +60,27 @@ "}\n", "\"\"\"\n", "\n", + "# Mixing settings\n", "mix_params = \\\n", "\"\"\"\n", - "dp_prior {\n", - " totalmass: 1\n", + "fixed_value {\n", + " totalmass: 1.0\n", "}\n", - "num_components: 3\n", "\"\"\"\n", "\n", + "# Algorithm settings\n", "algo_params = \\\n", "\"\"\"\n", - "algo_id: \"BlockedGibbs\"\n", + "algo_id: \"Neal8\"\n", "rng_seed: 20201124\n", "iterations: 2000\n", "burnin: 1000\n", "init_num_clusters: 3\n", + "neal8_n_aux: 3\n", "\"\"\"\n", "\n", - "dens_grid = np.linspace(-7.5,7.5,1000)" + "# Evaluation grid\n", + "dens_grid = np.linspace(-6.5, 6.5, 1000)" ] }, { @@ -85,26 +91,36 @@ "outputs": [], "source": [ "# Fit model using bayesmixpy\n", - "eval_dens, n_clus, clus_chain, best_clus = run_mcmc(\"NNxIG\",\"TruncSB\", data,\n", + "eval_dens, n_clus, clus_chain, best_clus = run_mcmc(\"NNxIG\",\"DP\", data,\n", " hier_params, mix_params, algo_params,\n", - " dens_grid, return_num_clusters=False,\n", - " return_clusters=False, return_best_clus=True)" + " dens_grid, return_num_clusters=True,\n", + " return_clusters=True, return_best_clus=True)" ] }, { "cell_type": "code", "execution_count": null, - "id": "1eb6c0e9", + "id": "8470fc0a", "metadata": {}, "outputs": [], "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n", + "\n", + "# Posterior distribution of clusters\n", + "x, y = np.unique(n_clus, return_counts=True)\n", + "axes[0].bar(x, y / y.sum())\n", + "axes[0].set_xticks(x)\n", + "axes[0].set_title(\"Posterior distribution of the number of clusters\")\n", + "\n", "# Plot mean posterior density\n", - "plt.plot(dens_grid, np.exp(eval_dens.mean(axis=0)))\n", - "plt.hist(data, alpha=0.4, density=True)\n", + "axes[1].plot(dens_grid, np.exp(np.mean(eval_dens, axis=0)))\n", + "axes[1].hist(data, alpha=0.3, density=True)\n", "for c in np.unique(best_clus):\n", " data_in_clus = data[best_clus == c]\n", - " plt.scatter(data_in_clus, np.zeros_like(data_in_clus) + 0.01)\n", - "plt.title(\"Posterior estimated density\")\n", + " axes[1].scatter(data_in_clus, np.zeros_like(data_in_clus) + 0.01)\n", + "axes[1].set_title(\"Posterior density estimate\")\n", + "\n", + "# Show results\n", "plt.show()" ] } From f0075c7f213362cd49521c60951c31b73c5a3689 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 22:45:20 +0100 Subject: [PATCH 093/317] Allow Neal8 to work with non-conjugate hierarchies --- src/algorithms/neal8_algorithm.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/algorithms/neal8_algorithm.h b/src/algorithms/neal8_algorithm.h index 55479db1a..045c3e4a9 100644 --- a/src/algorithms/neal8_algorithm.h +++ b/src/algorithms/neal8_algorithm.h @@ -23,6 +23,8 @@ class Neal8Algorithm : public Neal2Algorithm { Neal8Algorithm() = default; ~Neal8Algorithm() = default; + bool requires_conjugate_hierarchy() const override { return false; } + //! Returns number of auxiliary blocks unsigned int get_n_aux() const { return n_aux; } From 5265a32f0e7f36eeeb5117588f13eb243b73d22b Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 22:45:37 +0100 Subject: [PATCH 094/317] Add tests for NxIG prior model --- test/prior_models.cc | 90 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 4 deletions(-) diff --git a/test/prior_models.cc b/test/prior_models.cc index 3c55ebd70..dc392b9e0 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -6,9 +6,8 @@ #include "algorithm_state.pb.h" #include "hierarchy_prior.pb.h" -// #include "ls_state.pb.h" #include "src/hierarchies/priors/nig_prior_model.h" -// #include "src/utils/rng.h" +#include "src/hierarchies/priors/nxig_prior_model.h" TEST(nig_prior_model, set_get_hypers) { // Instance @@ -109,6 +108,7 @@ TEST(nig_prior_model, normal_mean_prior) { TEST(nig_prior_model, sample) { // Instance auto prior = std::make_shared(); + bool use_post_hypers = true; // Define prior hypers bayesmix::AlgorithmState::HierarchyHypers hypers_proto; @@ -119,9 +119,91 @@ TEST(nig_prior_model, sample) { // Set hypers and get sampled state as proto prior->set_hypers_from_proto(hypers_proto); - auto state1 = prior->sample(false); - auto state2 = prior->sample(false); + auto state1 = prior->sample(!use_post_hypers); + auto state2 = prior->sample(!use_post_hypers); // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); } + +TEST(nxig_prior_model, set_get_hypers) { + // Instance + auto prior = std::make_shared(); + + // Prepare buffers + bayesmix::NxIGDistribution hypers_; + bayesmix::AlgorithmState::HierarchyHypers set_state_; + bayesmix::AlgorithmState::HierarchyHypers got_state_; + + // Prepare hypers + hypers_.set_mean(5.0); + hypers_.set_var(1.2); + hypers_.set_shape(4.0); + hypers_.set_scale(3.0); + set_state_.mutable_nnxig_state()->CopyFrom(hypers_); + + // Set and get hypers + prior->set_hypers_from_proto(set_state_); + prior->write_hypers_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(nxig_prior_model, fixed_values_prior) { + // Prepare buffers + bayesmix::NNxIGPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + std::vector> prior_models; + std::vector states; + + // Set fixed value prior + prior.mutable_fixed_values()->set_mean(5.0); + prior.mutable_fixed_values()->set_var(1.2); + prior.mutable_fixed_values()->set_shape(2.0); + prior.mutable_fixed_values()->set_scale(2.0); + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Check equality before update + prior_models.push_back(prior_model); + for (size_t i = 1; i < 4; i++) { + prior_models.push_back(prior_model->clone()); + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnxig_state().DebugString()); + } + + // Check equality after update + prior_models[0]->update_hypers(states); + prior_models[0]->write_hypers_to_proto(&prior_out); + for (size_t i = 1; i < 4; i++) { + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnxig_state().DebugString()); + } +} + +TEST(nxig_prior_model, sample) { + // Instance + auto prior = std::make_shared(); + bool use_post_hypers = true; + + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + hypers_proto.mutable_nnxig_state()->set_mean(5.0); + hypers_proto.mutable_nnxig_state()->set_var(1.2); + hypers_proto.mutable_nnxig_state()->set_shape(4.0); + hypers_proto.mutable_nnxig_state()->set_scale(3.0); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + auto state1 = prior->sample(!use_post_hypers); + auto state2 = prior->sample(!use_post_hypers); + + // Check if they coincides + ASSERT_TRUE(state1->DebugString() != state2->DebugString()); +} \ No newline at end of file From 21f38628df0b093da4a50422525630489ca99850 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 22:46:02 +0100 Subject: [PATCH 095/317] Tests for NNxIG hierarchy --- test/hierarchies.cc | 94 +++++++++++++++++++++------------------------ 1 file changed, 43 insertions(+), 51 deletions(-) diff --git a/test/hierarchies.cc b/test/hierarchies.cc index a9b98d79b..b23819186 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -8,11 +8,11 @@ // #include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" // #include "src/hierarchies/nnw_hierarchy.h" -// #include "src/hierarchies/nnxig_hierarchy.h" +#include "src/hierarchies/nnxig_hierarchy.h" #include "src/utils/proto_utils.h" #include "src/utils/rng.h" -TEST(nnighierarchy, draw) { +TEST(nnig_hierarchy, draw) { auto hier = std::make_shared(); bayesmix::NNIGPrior prior; double mu0 = 5.0; @@ -38,7 +38,7 @@ TEST(nnighierarchy, draw) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -TEST(nnighierarchy, sample_given_data) { +TEST(nnig_hierarchy, sample_given_data) { auto hier = std::make_shared(); bayesmix::NNIGPrior prior; double mu0 = 5.0; @@ -57,7 +57,7 @@ TEST(nnighierarchy, sample_given_data) { datum << 4.5; auto hier2 = hier->clone(); - hier2->add_datum(0, datum, false); + hier2->add_datum(0, datum, true); hier2->sample_full_cond(); bayesmix::AlgorithmState out; @@ -223,59 +223,51 @@ TEST(nnighierarchy, sample_given_data) { // } // } -// TEST(nnxighierarchy, draw) { -// auto hier = std::make_shared(); -// bayesmix::NNxIGPrior prior; -// double mu0 = 5.0; -// double var0 = 1.0; -// double alpha0 = 2.0; -// double beta0 = 2.0; -// prior.mutable_fixed_values()->set_mean(mu0); -// prior.mutable_fixed_values()->set_var(var0); -// prior.mutable_fixed_values()->set_shape(alpha0); -// prior.mutable_fixed_values()->set_scale(beta0); -// hier->get_mutable_prior()->CopyFrom(prior); -// hier->initialize(); +TEST(nnxig_hierarchy, draw) { + auto hier = std::make_shared(); + bayesmix::NNxIGPrior prior; + prior.mutable_fixed_values()->set_mean(5.0); + prior.mutable_fixed_values()->set_var(1.2); + prior.mutable_fixed_values()->set_shape(2.0); + prior.mutable_fixed_values()->set_scale(2.0); + hier->get_mutable_prior()->CopyFrom(prior); + hier->initialize(); -// auto hier2 = hier->clone(); -// hier2->sample_prior(); + auto hier2 = hier->clone(); + hier2->sample_prior(); -// bayesmix::AlgorithmState out; -// bayesmix::AlgorithmState::ClusterState* clusval = -// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 -// = out.add_cluster_states(); hier->write_state_to_proto(clusval); -// hier2->write_state_to_proto(clusval2); + bayesmix::AlgorithmState out; + bayesmix::AlgorithmState::ClusterState* clusval = + out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 + = out.add_cluster_states(); hier->write_state_to_proto(clusval); + hier2->write_state_to_proto(clusval2); -// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -// } + ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +} -// TEST(nnxighierarchy, sample_given_data) { -// auto hier = std::make_shared(); -// bayesmix::NNxIGPrior prior; -// double mu0 = 5.0; -// double var0 = 1.0; -// double alpha0 = 2.0; -// double beta0 = 2.0; -// prior.mutable_fixed_values()->set_mean(mu0); -// prior.mutable_fixed_values()->set_var(var0); -// prior.mutable_fixed_values()->set_shape(alpha0); -// prior.mutable_fixed_values()->set_scale(beta0); -// hier->get_mutable_prior()->CopyFrom(prior); +TEST(nnxig_hierarchy, sample_given_data) { + auto hier = std::make_shared(); + bayesmix::NNxIGPrior prior; + prior.mutable_fixed_values()->set_mean(5.0); + prior.mutable_fixed_values()->set_var(1.2); + prior.mutable_fixed_values()->set_shape(2.0); + prior.mutable_fixed_values()->set_scale(2.0); + hier->get_mutable_prior()->CopyFrom(prior); -// hier->initialize(); + hier->initialize(); -// Eigen::VectorXd datum(1); -// datum << 4.5; + Eigen::VectorXd datum(1); + datum << 4.5; -// auto hier2 = hier->clone(); -// hier2->add_datum(0, datum, false); -// hier2->sample_full_cond(); + auto hier2 = hier->clone(); + hier2->add_datum(0, datum, true); + hier2->sample_full_cond(); -// bayesmix::AlgorithmState out; -// bayesmix::AlgorithmState::ClusterState* clusval = -// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 -// = out.add_cluster_states(); hier->write_state_to_proto(clusval); -// hier2->write_state_to_proto(clusval2); + bayesmix::AlgorithmState out; + bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); + bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); + hier->write_state_to_proto(clusval); + hier2->write_state_to_proto(clusval2); -// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -// } + ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +} From 6187428dcf1c3587daefcd87af60ef5a1398c770 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 22:46:29 +0100 Subject: [PATCH 096/317] Add NNxIG updater to target bayesmix --- src/hierarchies/updaters/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index 0b4135292..3efd34b34 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -4,4 +4,6 @@ target_sources(bayesmix conjugate_updater.h nnig_updater.h nnig_updater.cc + nnxig_updater.h + nnxig_updater.cc ) From 595d1042f1cf52bffa1a8c605538a00d89f9c4fa Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 22:47:00 +0100 Subject: [PATCH 097/317] Added NxIG prior model to target bayesmix --- src/hierarchies/priors/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt index 8547a6acf..7ed26bdeb 100644 --- a/src/hierarchies/priors/CMakeLists.txt +++ b/src/hierarchies/priors/CMakeLists.txt @@ -5,4 +5,6 @@ target_sources(bayesmix hyperparams.h nig_prior_model.h nig_prior_model.cc + nxig_prior_model.h + nxig_prior_model.cc ) From 7cc58bf7ad4e303c37c53ba94bc04485f52560d6 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 22:48:43 +0100 Subject: [PATCH 098/317] Add NNxIG hierarchy to target bayesmix --- src/hierarchies/CMakeLists.txt | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index 234a44f6d..537299476 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -2,15 +2,13 @@ target_sources(bayesmix PUBLIC abstract_hierarchy.h base_hierarchy.h + nnig_hierarchy.h + nnxig_hierarchy.h # conjugate_hierarchy.h # lin_reg_uni_hierarchy.h # lin_reg_uni_hierarchy.cc - nnig_hierarchy.h - # nnig_hierarchy.cc # nnw_hierarchy.h # nnw_hierarchy.cc - # nnxig_hierarchy.h - # nnxig_hierarchy.cc ) add_subdirectory(likelihoods) From b4efa84abdb4d4490349bef6da6c66515b98b880 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 22:49:04 +0100 Subject: [PATCH 099/317] Now builder load NNxIG hierarchy --- src/hierarchies/load_hierarchies.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/load_hierarchies.h b/src/hierarchies/load_hierarchies.h index 982960f0a..e42a3af92 100644 --- a/src/hierarchies/load_hierarchies.h +++ b/src/hierarchies/load_hierarchies.h @@ -9,7 +9,7 @@ // #include "lin_reg_uni_hierarchy.h" #include "nnig_hierarchy.h" // #include "nnw_hierarchy.h" -// #include "nnxig_hierarchy.h" +#include "nnxig_hierarchy.h" #include "src/runtime/factory.h" //! Loads all available `Hierarchy` objects into the appropriate factory, so @@ -26,9 +26,9 @@ __attribute__((constructor)) static void load_hierarchies() { Builder NNIGbuilder = []() { return std::make_shared(); }; - // Builder NNxIGbuilder = []() { - // return std::make_shared(); - // }; + Builder NNxIGbuilder = []() { + return std::make_shared(); + }; // Builder NNWbuilder = []() { // return std::make_shared(); // }; @@ -37,7 +37,7 @@ __attribute__((constructor)) static void load_hierarchies() { // }; factory.add_builder(NNIGHierarchy().get_id(), NNIGbuilder); - // factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); + factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); // factory.add_builder(NNWHierarchy().get_id(), NNWbuilder); // factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); } From ec00d7967dbb4bb0a9f4d71989d4786fc431efec Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 26 Jan 2022 22:49:32 +0100 Subject: [PATCH 100/317] Added NNxIG hierarchy via composition --- src/hierarchies/nnxig_hierarchy.h | 42 ++++++++++++ src/hierarchies/priors/nxig_prior_model.cc | 77 ++++++++++++++++++++++ src/hierarchies/priors/nxig_prior_model.h | 41 ++++++++++++ src/hierarchies/updaters/nnxig_updater.cc | 38 +++++++++++ src/hierarchies/updaters/nnxig_updater.h | 18 +++++ 5 files changed, 216 insertions(+) create mode 100644 src/hierarchies/nnxig_hierarchy.h create mode 100644 src/hierarchies/priors/nxig_prior_model.cc create mode 100644 src/hierarchies/priors/nxig_prior_model.h create mode 100644 src/hierarchies/updaters/nnxig_updater.cc create mode 100644 src/hierarchies/updaters/nnxig_updater.h diff --git a/src/hierarchies/nnxig_hierarchy.h b/src/hierarchies/nnxig_hierarchy.h new file mode 100644 index 000000000..9d5e7ff9f --- /dev/null +++ b/src/hierarchies/nnxig_hierarchy.h @@ -0,0 +1,42 @@ +#ifndef BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ + +// #include + +// #include +// #include +// #include + +// #include "algorithm_state.pb.h" +// #include "conjugate_hierarchy.h" +#include "hierarchy_id.pb.h" +// #include "hierarchy_prior.pb.h" + +#include "base_hierarchy.h" +#include "likelihoods/uni_norm_likelihood.h" +#include "priors/nxig_prior_model.h" +#include "updaters/nnxig_updater.h" + +class NNxIGHierarchy : public BaseHierarchy { + public: + NNxIGHierarchy() = default; + ~NNxIGHierarchy() = default; + + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::NNxIG; + } + + void initialize_state() override { + // Get hypers + auto hypers = prior->get_hypers(); + // Initialize likelihood state + State::UniLS state; + state.mean = hypers.mean; + state.var = hypers.scale / (hypers.shape + 1); + like->set_state(state); + }; + +}; + +#endif // BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ diff --git a/src/hierarchies/priors/nxig_prior_model.cc b/src/hierarchies/priors/nxig_prior_model.cc new file mode 100644 index 000000000..2729fe67e --- /dev/null +++ b/src/hierarchies/priors/nxig_prior_model.cc @@ -0,0 +1,77 @@ +#include "nxig_prior_model.h" + +void NxIGPriorModel::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers.mean = prior->fixed_values().mean(); + hypers.var = prior->fixed_values().var(); + hypers.shape = prior->fixed_values().shape(); + hypers.scale = prior->fixed_values().scale(); + // Check validity + if (hypers.var <= 0) { + throw std::invalid_argument("Variance parameter must be > 0"); + } + if (hypers.shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers.scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +double NxIGPriorModel::lpdf(const google::protobuf::Message &state_) { + auto &state = downcast_state(state_).uni_ls_state(); + double target = + stan::math::normal_lpdf(state.mean(), hypers.mean,sqrt(hypers.var)) + + stan::math::inv_gamma_lpdf(state.var(), hypers.shape, hypers.scale); + return target; +} + +std::shared_ptr NxIGPriorModel::sample( + bool use_post_hypers) { + auto &rng = bayesmix::Rng::Instance().get(); + Hyperparams::NxIG params = use_post_hypers ? post_hypers : hypers; + + double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); + double mean = + stan::math::normal_rng(params.mean, sqrt(params.var), rng); + + bayesmix::AlgorithmState::ClusterState state; + state.mutable_uni_ls_state()->set_mean(mean); + state.mutable_uni_ls_state()->set_var(var); + return std::make_shared(state); +}; + +void NxIGPriorModel::update_hypers( + const std::vector &states) { + if (prior->has_fixed_values()) { + return; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void NxIGPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).nnxig_state(); + hypers.mean = hyperscast.mean(); + hypers.var = hyperscast.var(); + hypers.scale = hyperscast.scale(); + hypers.shape = hyperscast.shape(); +} + +std::shared_ptr +NxIGPriorModel::get_hypers_proto() const { + bayesmix::NxIGDistribution hypers_; + hypers_.set_mean(hypers.mean); + hypers_.set_var(hypers.var); + hypers_.set_shape(hypers.shape); + hypers_.set_scale(hypers.scale); + + auto out = std::make_shared(); + out->mutable_nnxig_state()->CopyFrom(hypers_); + return out; +} diff --git a/src/hierarchies/priors/nxig_prior_model.h b/src/hierarchies/priors/nxig_prior_model.h new file mode 100644 index 000000000..a7282187a --- /dev/null +++ b/src/hierarchies/priors/nxig_prior_model.h @@ -0,0 +1,41 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_NXIG_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_NXIG_PRIOR_MODEL_H_ + +// #include + +#include +#include +#include +#include + +// #include "algorithm_state.pb.h" +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +class NxIGPriorModel : public BasePriorModel { + public: + NxIGPriorModel() = default; + ~NxIGPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + std::shared_ptr sample( + bool use_post_hypers) override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + protected: + std::shared_ptr get_hypers_proto() + const override; + + void initialize_hypers() override; +}; + +#endif // BAYESMIX_HIERARCHIES_PRIORS_NXIG_PRIOR_MODEL_H_ diff --git a/src/hierarchies/updaters/nnxig_updater.cc b/src/hierarchies/updaters/nnxig_updater.cc new file mode 100644 index 000000000..3c7064081 --- /dev/null +++ b/src/hierarchies/updaters/nnxig_updater.cc @@ -0,0 +1,38 @@ +#include "nnxig_updater.h" + +#include "src/hierarchies/likelihoods/states.h" +#include "src/hierarchies/priors/hyperparams.h" + +void NNxIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Getting required quantities from likelihood and prior + auto state = likecast.get_state(); + int card = likecast.get_card(); + double data_sum = likecast.get_data_sum(); + double data_sum_squares = likecast.get_data_sum_squares(); + auto hypers = priorcast.get_hypers(); + + // No update possible + if (card == 0) { + priorcast.set_posterior_hypers(hypers); + } + + // Compute posterior hyperparameters + Hyperparams::NxIG post_params; + double var_y = data_sum_squares - 2 * state.mean * data_sum + + card * state.mean * state.mean; + post_params.mean = (hypers.var * data_sum + state.var * hypers.mean) / + (card * hypers.var + state.var); + post_params.var = + (state.var * hypers.var) / (card * hypers.var + state.var); + post_params.shape = hypers.shape + 0.5 * card; + post_params.scale = hypers.scale + 0.5 * var_y; + priorcast.set_posterior_hypers(post_params); + return; +}; + + diff --git a/src/hierarchies/updaters/nnxig_updater.h b/src/hierarchies/updaters/nnxig_updater.h new file mode 100644 index 000000000..15e680db8 --- /dev/null +++ b/src/hierarchies/updaters/nnxig_updater.h @@ -0,0 +1,18 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_NNXIG_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_NNXIG_UPDATER_H_ + +#include "conjugate_updater.h" +#include "src/hierarchies/likelihoods/uni_norm_likelihood.h" +#include "src/hierarchies/priors/nxig_prior_model.h" + +class NNxIGUpdater : public ConjugateUpdater { + public: + NNxIGUpdater() = default; + ~NNxIGUpdater() = default; + + bool is_conjugate() const override { return false; }; + void compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) override; +}; + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_NNXIG_UPDATER_H_ From 8db85790a4d80c427bf7df7f2e28f727bbf76065 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 27 Jan 2022 08:42:05 +0100 Subject: [PATCH 101/317] multivariate ls state --- CMakeLists.txt | 3 ++ src/hierarchies/likelihoods/states.h | 46 +++++++++++++++++++++++++++- test/CMakeLists.txt | 1 + test/likelihoods.cc | 17 ++++++++++ 4 files changed, 66 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ba0d0eee..5d600cce7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,7 @@ set(INCLUDE_PATHS ${CMAKE_CURRENT_LIST_DIR}/lib/math ${CMAKE_CURRENT_LIST_DIR}/lib/math/lib/boost_1.72.0 ${CMAKE_CURRENT_LIST_DIR}/lib/math/lib/eigen_3.3.9 + ${CMAKE_CURRENT_LIST_DIR}/lib/math/lib/sundials_5.5.0/include # TBB already included ${CMAKE_CURRENT_BINARY_DIR} ${protobuf_SOURCE_DIR}/src @@ -157,6 +158,8 @@ if (BUILD_RUN) target_compile_options(run_mcmc PUBLIC ${COMPILE_OPTIONS}) endif() +add_subdirectory(test) + if (NOT DISABLE_TESTS) add_subdirectory(test) endif() diff --git a/src/hierarchies/likelihoods/states.h b/src/hierarchies/likelihoods/states.h index 650fe0003..0c4ffdcf8 100644 --- a/src/hierarchies/likelihoods/states.h +++ b/src/hierarchies/likelihoods/states.h @@ -5,6 +5,7 @@ #include #include "algorithm_state.pb.h" +#include "src/utils/proto_utils.h" namespace State { @@ -42,10 +43,53 @@ class UniLS { } }; -struct MultiLS { +class MultiLS { + public: Eigen::VectorXd mean; Eigen::MatrixXd prec, prec_chol; double prec_logdet; + + Eigen::VectorXd get_unconstrained() { + Eigen::VectorXd out_prec = stan::math::cov_matrix_free(prec); + Eigen::VectorXd out(mean.size() + out_prec.size()); + out << mean, out_prec; + return out; + } + + void set_from_unconstrained(Eigen::VectorXd in) { + double dim_ = 0.5 * (std::sqrt(8 * in.size() + 9) - 3); + double dim; + assert(modf(dim_, &dim) == 0.0); + mean = in.head(int(dim)); + prec = + stan::math::cov_matrix_constrain(in.tail(int(in.size() - dim)), dim); + prec_chol = Eigen::LLT(prec).matrixL(); + Eigen::VectorXd diag = prec_chol.diagonal(); + prec_logdet = 2 * log(diag.array()).sum(); + } + + void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { + mean = to_eigen(state_.multi_ls_state().mean()); + prec = to_eigen(state_.multi_ls_state().prec()); + prec_chol = to_eigen(state_.multi_ls_state().prec_chol()); + Eigen::VectorXd diag = prec_chol.diagonal(); + prec_logdet = 2 * log(diag.array()).sum(); + } + + bayesmix::AlgorithmState::ClusterState get_as_proto() { + bayesmix::AlgorithmState::ClusterState state; + bayesmix::to_proto(mean, state.mutable_multi_ls_state()->mutable_mean()); + bayesmix::to_proto(prec, state.mutable_multi_ls_state()->mutable_prec()); + bayesmix::to_proto(prec_chol, + state.mutable_multi_ls_state()->mutable_prec_chol()); + return state; + } + + double log_det_jac() { + double out = 0; + stan::math::positive_constrain(stan::math::cov_matrix_free(prec), out); + return out; + } }; struct UniLinReg { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 22118d049..26b5eb903 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -30,6 +30,7 @@ add_executable(test_bayesmix $ # runtime.cc # rng.cc # logit_sb.cc + gradient.cc ) target_include_directories(test_bayesmix PUBLIC ${INCLUDE_PATHS}) diff --git a/test/likelihoods.cc b/test/likelihoods.cc index c0bdfb276..8335d653d 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -108,3 +108,20 @@ TEST(uni_norm_likelihood, eval_lpdf_unconstrained) { clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); ASSERT_TRUE(std::abs(clus_lpdf - lpdf) > 1e-5); } + +TEST(multi_ls_state, set_unconstrained) { + auto& rng = bayesmix::Rng::Instance().get(); + + State::MultiLS state; + auto mean = Eigen::VectorXd::Zero(5); + auto prec = + stan::math::wishart_rng(10, Eigen::MatrixXd::Identity(5, 5), rng); + state.mean = mean; + state.prec = prec; + Eigen::VectorXd unconstrained_state = state.get_unconstrained(); + + State::MultiLS state2; + state2.set_from_unconstrained(unconstrained_state); + ASSERT_TRUE((state.mean - state2.mean).squaredNorm() < 1e-5); + ASSERT_TRUE((state.prec - state2.prec).squaredNorm() < 1e-5); +} From 5d30e1d59382f5f5d2a6138418ec71e8691537db Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 27 Jan 2022 09:02:38 +0100 Subject: [PATCH 102/317] using lower cholesky factor --- src/hierarchies/likelihoods/states.h | 12 +++++++++--- src/utils/distributions.cc | 22 ++++++++++++---------- test/CMakeLists.txt | 2 +- test/distributions.cc | 15 +++++++++------ 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/src/hierarchies/likelihoods/states.h b/src/hierarchies/likelihoods/states.h index 0c4ffdcf8..530c90123 100644 --- a/src/hierarchies/likelihoods/states.h +++ b/src/hierarchies/likelihoods/states.h @@ -56,6 +56,14 @@ class MultiLS { return out; } + void set_from_constrained(Eigen::VectorXd mean_, Eigen::MatrixXd prec_) { + mean = mean_; + prec = prec_; + prec_chol = Eigen::LLT(prec).matrixL(); + Eigen::VectorXd diag = prec_chol.diagonal(); + prec_logdet = 2 * log(diag.array()).sum(); + } + void set_from_unconstrained(Eigen::VectorXd in) { double dim_ = 0.5 * (std::sqrt(8 * in.size() + 9) - 3); double dim; @@ -63,9 +71,7 @@ class MultiLS { mean = in.head(int(dim)); prec = stan::math::cov_matrix_constrain(in.tail(int(in.size() - dim)), dim); - prec_chol = Eigen::LLT(prec).matrixL(); - Eigen::VectorXd diag = prec_chol.diagonal(); - prec_logdet = 2 * log(diag.array()).sum(); + set_from_constrained(mean, prec); } void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { diff --git a/src/utils/distributions.cc b/src/utils/distributions.cc index ad864d279..52bcdf665 100644 --- a/src/utils/distributions.cc +++ b/src/utils/distributions.cc @@ -16,22 +16,24 @@ double bayesmix::multi_normal_prec_lpdf(const Eigen::VectorXd &datum, const Eigen::MatrixXd &prec_chol, double prec_logdet) { using stan::math::NEG_LOG_SQRT_TWO_PI; - double base = prec_logdet + NEG_LOG_SQRT_TWO_PI * datum.size(); - double exp = (prec_chol * (datum - mean)).squaredNorm(); - return 0.5 * (base - exp); + double base = 0.5 * prec_logdet + NEG_LOG_SQRT_TWO_PI * datum.size(); + double exp = 0.5 * (prec_chol.transpose() * (datum - mean)).squaredNorm(); + return base - exp; } Eigen::VectorXd bayesmix::multi_normal_prec_lpdf_grid( const Eigen::MatrixXd &data, const Eigen::VectorXd &mean, const Eigen::MatrixXd &prec_chol, double prec_logdet) { using stan::math::NEG_LOG_SQRT_TWO_PI; - Eigen::VectorXd exp = - ((data.rowwise() - mean.transpose()) * prec_chol.transpose()) - .rowwise() - .squaredNorm(); - Eigen::VectorXd base = Eigen::ArrayXd::Ones(data.rows()) * prec_logdet + - NEG_LOG_SQRT_TWO_PI * data.cols(); - return (base - exp) * 0.5; + + Eigen::VectorXd exp = ((data.rowwise() - mean.transpose()) * prec_chol) + .rowwise() + .squaredNorm() * + 0.5; + Eigen::VectorXd base = + Eigen::ArrayXd::Ones(data.rows()) * prec_logdet * 0.5 + + NEG_LOG_SQRT_TWO_PI * data.cols(); + return base - exp; } double bayesmix::multi_student_t_invscale_lpdf( diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 26b5eb903..d58fb0afb 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -24,7 +24,7 @@ add_executable(test_bayesmix $ # lpdf.cc # priors.cc // OLD, USEREI prior_models.cc # eigen_utils.cc - # distributions.cc + distributions.cc # semi_hdp.cc # collectors.cc # runtime.cc diff --git a/test/distributions.cc b/test/distributions.cc index ee1e480ef..acfbe7e31 100644 --- a/test/distributions.cc +++ b/test/distributions.cc @@ -6,6 +6,7 @@ #include #include +#include "src/hierarchies/likelihoods/states.h" #include "src/utils/rng.h" TEST(mix_dist, 1) { @@ -152,17 +153,19 @@ TEST(mult_normal, lpdf_grid) { Eigen::MatrixXd tmp = Eigen::MatrixXd::Random(dim + 1, dim); Eigen::MatrixXd prec = tmp.transpose() * tmp + Eigen::MatrixXd::Identity(dim, dim); - Eigen::MatrixXd prec_chol = Eigen::LLT(prec).matrixU(); - Eigen::VectorXd diag = prec_chol.diagonal(); - double prec_logdet = 2 * log(diag.array()).sum(); + State::MultiLS state; + state.set_from_constrained(mean, prec); Eigen::VectorXd lpdfs = bayesmix::multi_normal_prec_lpdf_grid( - data, mean, prec_chol, prec_logdet); + data, state.mean, state.prec_chol, state.prec_logdet); for (int i = 0; i < 20; i++) { - double curr = bayesmix::multi_normal_prec_lpdf(data.row(i), mean, - prec_chol, prec_logdet); + double curr = bayesmix::multi_normal_prec_lpdf( + data.row(i), state.mean, state.prec_chol, state.prec_logdet); + double curr2 = stan::math::multi_normal_prec_lpdf(data.row(i), state.mean, + state.prec); ASSERT_DOUBLE_EQ(curr, lpdfs(i)); + ASSERT_DOUBLE_EQ(curr, curr2); } } From 03b38f37f6153d7615f382d7b1d00f5c3bf65c34 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 27 Jan 2022 09:09:52 +0100 Subject: [PATCH 103/317] Fix header guards --- src/hierarchies/likelihoods/abstract_likelihood.h | 6 +++--- src/hierarchies/likelihoods/base_likelihood.h | 6 +++--- src/hierarchies/likelihoods/states.h | 6 +++--- src/hierarchies/likelihoods/uni_norm_likelihood.h | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 9beef7729..8fcf3bbb7 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -1,5 +1,5 @@ -#ifndef BAYESMIX_HIERARCHIES_ABSTRACT_LIKELIHOOD_H_ -#define BAYESMIX_HIERARCHIES_ABSTRACT_LIKELIHOOD_H_ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_ABSTRACT_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_ABSTRACT_LIKELIHOOD_H_ #include @@ -114,4 +114,4 @@ class AbstractLikelihood { } }; -#endif // BAYESMIX_HIERARCHIES_ABSTRACT_LIKELIHOOD_H_ +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_ABSTRACT_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index a68f7c823..fe806f377 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -1,5 +1,5 @@ -#ifndef BAYESMIX_HIERARCHIES_BASE_LIKELIHOOD_H_ -#define BAYESMIX_HIERARCHIES_BASE_LIKELIHOOD_H_ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_BASE_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_BASE_LIKELIHOOD_H_ #include @@ -141,4 +141,4 @@ Eigen::VectorXd BaseLikelihood::lpdf_grid( return lpdf; } -#endif // BAYESMIX_HIERARCHIES_BASE_LIKELIHOOD_H_ +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_BASE_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/states.h b/src/hierarchies/likelihoods/states.h index 70dc7f032..3acc73838 100644 --- a/src/hierarchies/likelihoods/states.h +++ b/src/hierarchies/likelihoods/states.h @@ -1,5 +1,5 @@ -#ifndef BAYESMIX_HIERARCHIES_LIKELIHOOD_STATES_H_ -#define BAYESMIX_HIERARCHIES_LIKELIHOOD_STATES_H_ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_H_ #include @@ -22,4 +22,4 @@ struct UniLinReg { } // namespace State -#endif // BAYESMIX_HIERARCHIES_LIKELIHOOD_STATES_H_ +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_H_ diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 133e1f81e..ac5e044e9 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -1,5 +1,5 @@ -#ifndef BAYESMIX_HIERARCHIES_UNI_NORM_LIKELIHOOD_H_ -#define BAYESMIX_HIERARCHIES_UNI_NORM_LIKELIHOOD_H_ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_NORM_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_NORM_LIKELIHOOD_H_ #include @@ -35,4 +35,4 @@ class UniNormLikelihood double data_sum_squares = 0; }; -#endif // BAYESMIX_HIERARCHIES_UNI_NORM_LIKELIHOOD_H_ +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_NORM_LIKELIHOOD_H_ From 31fc8000833e8fad62a58d2357ab4d78e39c6550 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 27 Jan 2022 09:12:13 +0100 Subject: [PATCH 104/317] Fix header guards --- src/hierarchies/priors/abstract_prior_model.h | 6 +++--- src/hierarchies/priors/base_prior_model.h | 6 +++--- src/hierarchies/priors/hyperparams.h | 6 +++--- src/hierarchies/priors/nig_prior_model.h | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 5ccc7ddfc..cf76c2832 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -1,5 +1,5 @@ -#ifndef BAYESMIX_HIERARCHIES_ABSTRACT_PRIORMODEL_H_ -#define BAYESMIX_HIERARCHIES_ABSTRACT_PRIORMODEL_H_ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_ABSTRACT_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_ABSTRACT_PRIOR_MODEL_H_ #include @@ -42,4 +42,4 @@ class AbstractPriorModel { virtual void initialize_hypers() = 0; }; -#endif // BAYESMIX_HIERARCHIES_ABSTRACT_PRIORMODEL_H_ +#endif // BAYESMIX_HIERARCHIES_PRIORS_ABSTRACT_PRIOR_MODEL_H_ diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index afc21a0d7..5627f78bf 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -1,5 +1,5 @@ -#ifndef BAYESMIX_HIERARCHIES_BASE_PRIORMODEL_H_ -#define BAYESMIX_HIERARCHIES_BASE_PRIORMODEL_H_ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_BASE_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_BASE_PRIOR_MODEL_H_ #include @@ -105,4 +105,4 @@ void BasePriorModel::check_prior_is_set() const { } } -#endif // BAYESMIX_HIERARCHIES_BASE_PRIORMODEL_H_ +#endif // BAYESMIX_HIERARCHIES_PRIORS_BASE_PRIOR_MODEL_H_ diff --git a/src/hierarchies/priors/hyperparams.h b/src/hierarchies/priors/hyperparams.h index b338aa1fe..d91806a42 100644 --- a/src/hierarchies/priors/hyperparams.h +++ b/src/hierarchies/priors/hyperparams.h @@ -1,5 +1,5 @@ -#ifndef BAYESMIX_HIERARCHIES_PRIORMODEL_HYPERPARAMS_H_ -#define BAYESMIX_HIERARCHIES_PRIORMODEL_HYPERPARAMS_H_ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_HYPERPARAMS_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_HYPERPARAMS_H_ #include @@ -27,4 +27,4 @@ struct MNIG { } // namespace Hyperparams -#endif // BAYESMIX_HIERARCHIES_PRIORMODEL_HYPERPARAMS_H_ +#endif // BAYESMIX_HIERARCHIES_PRIORS_HYPERPARAMS_H_ diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index 1ffda2007..8f80f22eb 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -1,5 +1,5 @@ -#ifndef BAYESMIX_HIERARCHIES_NIG_PRIOR_MODEL_H_ -#define BAYESMIX_HIERARCHIES_NIG_PRIOR_MODEL_H_ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_NIG_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_NIG_PRIOR_MODEL_H_ // #include @@ -38,4 +38,4 @@ class NIGPriorModel : public BasePriorModel Date: Thu, 27 Jan 2022 09:14:49 +0100 Subject: [PATCH 105/317] Add NNxIG hierarchy include --- src/includes.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/includes.h b/src/includes.h index 39e03bea3..ae41aeefe 100644 --- a/src/includes.h +++ b/src/includes.h @@ -13,7 +13,7 @@ #include "hierarchies/load_hierarchies.h" #include "hierarchies/nnig_hierarchy.h" // #include "hierarchies/nnw_hierarchy.h" -// #include "hierarchies/nnxig_hierarchy.h" +#include "hierarchies/nnxig_hierarchy.h" #include "mixings/dirichlet_mixing.h" #include "mixings/load_mixings.h" #include "mixings/logit_sb_mixing.h" From 7671daceceaecdb74313ae71cddd976a78f321ea Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 27 Jan 2022 14:06:22 +0100 Subject: [PATCH 106/317] gradient test --- test/gradient.cc | 56 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 test/gradient.cc diff --git a/test/gradient.cc b/test/gradient.cc new file mode 100644 index 000000000..516dceefa --- /dev/null +++ b/test/gradient.cc @@ -0,0 +1,56 @@ +#include + +#include + + +class fbase { + public: + virtual double lpdf(const Eigen::VectorXd & x) = 0; +}; + + +class f1: public fbase { + protected: + double y; + public: + f1() = default; + f1(double y): y(y) {} + + template + T lpdf(const Eigen::Matrix & x) const { + return 0.5 * x.squaredNorm() * y; + } + + double lpdf(const Eigen::Matrix & x) { + return this->lpdf(x); + } +}; + +template +struct target_lpdf { + F f; + + template + T operator() (const Eigen::Matrix & x) const { + return f.lpdf(x); + } +}; + + + +TEST(gradient, quadratic_function) { + + Eigen::VectorXd out; + Eigen::VectorXd x(5); + x << 1.0, 2.0, 3.0, 4.0, 5.0; + target_lpdf target_function; + target_function.f = f1(5.0); + double y; + stan::math::gradient(target_function, x, y, out); + + for (int i=0; i < 5; i++) { + ASSERT_DOUBLE_EQ(out(i), 5 * x(i)); + } +} + + From 5bce13ffea0bfd0b7758eae854b8724ab41a04b8 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 27 Jan 2022 14:26:33 +0100 Subject: [PATCH 107/317] checkpoint --- .../likelihoods/abstract_likelihood.h | 2 +- .../likelihoods/uni_norm_likelihood.cc | 21 +++--- .../likelihoods/uni_norm_likelihood.h | 19 ++++- test/gradient.cc | 70 +++++++++---------- 4 files changed, 60 insertions(+), 52 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index ec3d5ce75..413937dd5 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -38,7 +38,7 @@ class AbstractLikelihood { //! Usually, some kind of transformation is required from the unconstrained //! parameterization to the actual parameterization. virtual double cluster_lpdf_from_unconstrained( - Eigen::VectorXd unconstrained_params) { + Eigen::VectorXd unconstrained_params) const { throw std::runtime_error( "cluster_lpdf_from_unconstrained() not yet implemented"); } diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc index ce5419116..bda296076 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -36,13 +36,14 @@ void UniNormLikelihood::clear_summary_statistics() { data_sum_squares = 0; } -double UniNormLikelihood::cluster_lpdf_from_unconstrained( - Eigen::VectorXd unconstrained_params) { - assert(unconstrained_params.size() == 2); - double mean = unconstrained_params(0); - double var = std::exp(unconstrained_params(1)); - double out = -(data_sum_squares - 2 * mean * data_sum + card * mean * mean) / - (2 * var); - out -= card * 0.5 * std::log(stan::math::TWO_PI * var); - return out; -} +// double UniNormLikelihood::cluster_lpdf_from_unconstrained( +// Eigen::VectorXd unconstrained_params) { +// assert(unconstrained_params.size() == 2); +// double mean = unconstrained_params(0); +// double var = std::exp(unconstrained_params(1)); +// double out = -(data_sum_squares - 2 * mean * data_sum + card * mean * +// mean) / +// (2 * var); +// out -= card * 0.5 * std::log(stan::math::TWO_PI * var); +// return out; +// } diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 5ce0ed9c2..bc3108283 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -3,9 +3,8 @@ #include -#include #include -#include +#include #include #include "algorithm_state.pb.h" @@ -25,9 +24,23 @@ class UniNormLikelihood double get_data_sum() const { return data_sum; }; double get_data_sum_squares() const { return data_sum_squares; }; + template + T cluster_lpdf_from_unconstrained( + const Eigen::Matrix &unconstrained_params) const { + assert(unconstrained_params.size() == 2); + T mean = unconstrained_params(0); + T var = stan::math::positive_constrain(unconstrained_params(1)); + T out = -(data_sum_squares - 2 * mean * data_sum + card * mean * mean) / + (2 * var); + out -= card * 0.5 * std::log(stan::math::TWO_PI * var); + return out; + } + // The unconstrained parameters are mean and log(var) double cluster_lpdf_from_unconstrained( - Eigen::VectorXd unconstrained_params) override; + Eigen::VectorXd unconstrained_params) const override { + return this->cluster_lpdf_from_unconstrained(unconstrained_params); + } protected: std::shared_ptr get_state_proto() diff --git a/test/gradient.cc b/test/gradient.cc index 516dceefa..3ca4c1a13 100644 --- a/test/gradient.cc +++ b/test/gradient.cc @@ -2,55 +2,49 @@ #include - class fbase { - public: - virtual double lpdf(const Eigen::VectorXd & x) = 0; + public: + virtual double lpdf(const Eigen::VectorXd& x) = 0; }; +class f1 : public fbase { + protected: + double y; -class f1: public fbase { - protected: - double y; - public: - f1() = default; - f1(double y): y(y) {} + public: + f1() = default; + f1(double y) : y(y) {} - template - T lpdf(const Eigen::Matrix & x) const { - return 0.5 * x.squaredNorm() * y; - } + template + T lpdf(const Eigen::Matrix& x) const { + return 0.5 * x.squaredNorm() * y; + } - double lpdf(const Eigen::Matrix & x) { - return this->lpdf(x); - } + double lpdf(const Eigen::Matrix& x) { + return this->lpdf(x); + } }; -template +template struct target_lpdf { - F f; + F f; - template - T operator() (const Eigen::Matrix & x) const { - return f.lpdf(x); - } + template + T operator()(const Eigen::Matrix& x) const { + return f.lpdf(x); + } }; - - TEST(gradient, quadratic_function) { - - Eigen::VectorXd out; - Eigen::VectorXd x(5); - x << 1.0, 2.0, 3.0, 4.0, 5.0; - target_lpdf target_function; - target_function.f = f1(5.0); - double y; - stan::math::gradient(target_function, x, y, out); - - for (int i=0; i < 5; i++) { - ASSERT_DOUBLE_EQ(out(i), 5 * x(i)); - } + Eigen::VectorXd out; + Eigen::VectorXd x(5); + x << 1.0, 2.0, 3.0, 4.0, 5.0; + target_lpdf target_function; + target_function.f = f1(5.0); + double y; + stan::math::gradient(target_function, x, y, out); + + for (int i = 0; i < 5; i++) { + ASSERT_DOUBLE_EQ(out(i), 5 * x(i)); + } } - - From c30fdf0f1b9b9619b9ef2163862bcf4bf97c0f31 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 27 Jan 2022 15:45:35 +0100 Subject: [PATCH 108/317] first draft with AD --- src/hierarchies/base_hierarchy.h | 398 +++++------------- src/hierarchies/updaters/abstract_updater.h | 4 +- src/hierarchies/updaters/metropolis_updater.h | 36 +- .../updaters/random_walk_updater.h | 10 +- 4 files changed, 139 insertions(+), 309 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index b6050d94e..9adca2835 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -3,17 +3,20 @@ #include -#include #include #include #include #include +#include #include "abstract_hierarchy.h" #include "algorithm_state.pb.h" #include "hierarchy_id.pb.h" #include "src/utils/rng.h" +template +class target_lpdf_unconstrained; + //! Base template class for a hierarchy object. //! This class is a templatized version of, and derived from, the @@ -32,6 +35,10 @@ template class BaseHierarchy : public AbstractHierarchy { + + template + friend class target_lpdf_unconstrained; + protected: std::shared_ptr like = std::make_shared(); std::shared_ptr prior = std::make_shared(); @@ -165,11 +172,7 @@ class BaseHierarchy : public AbstractHierarchy { // like->set_card(card); }; - void sample_full_cond(bool update_params = false) override { - // int card = like->get_card(); - updater->draw(*like, *prior, update_params); - // like->set_card(card); - }; + void sample_full_cond(bool update_params = false) override; void sample_full_cond( const Eigen::MatrixXd &data, @@ -289,289 +292,104 @@ class BaseHierarchy : public AbstractHierarchy { } }; -// //! Returns an independent, data-less copy of this object -// virtual std::shared_ptr clone() const override { -// auto out = std::make_shared(static_cast(*this)); out->clear_data(); out->clear_summary_statistics(); return -// out; -// } - -// //! Evaluates the log-likelihood of data in a grid of points -// //! @param data Grid of points (by row) which are to be evaluated -// //! @param covariates (Optional) covariate vectors associated to data -// //! @return The evaluation of the lpdf -// virtual Eigen::VectorXd like_lpdf_grid( -// const Eigen::MatrixXd &data, -// const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, -// 0)) const -// override; - -// //! Generates new state values from the centering prior distribution -// void sample_prior() override { -// state = static_cast(this)->draw(*hypers); -// } - -// //! Overloaded version of sample_full_cond(bool), mainly used for -// debugging virtual void sample_full_cond( -// const Eigen::MatrixXd &data, -// const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override; - -// //! Returns the current cardinality of the cluster -// int get_card() const override { return card; } - -// //! Returns the logarithm of the current cardinality of the cluster -// double get_log_card() const override { return log_card; } - -// //! Returns the indexes of data points belonging to this cluster -// std::set get_data_idx() const override { return cluster_data_idx; } - -// //! Returns a pointer to the Protobuf message of the prior of this cluster -// virtual google::protobuf::Message *get_mutable_prior() override { -// if (prior == nullptr) { -// create_empty_prior(); -// } -// return prior.get(); -// } - -// //! Writes current state to a Protobuf message by pointer -// void write_state_to_proto(google::protobuf::Message *out) const override; - -// //! Writes current values of the hyperparameters to a Protobuf message by -// //! pointer -// void write_hypers_to_proto(google::protobuf::Message *out) const override; - -// //! Returns the struct of the current state -// State get_state() const { return state; } - -// //! Returns the struct of the current prior hyperparameters -// Hyperparams get_hypers() const { return *hypers; } - -// //! Returns the struct of the current posterior hyperparameters -// Hyperparams get_posterior_hypers() const { return posterior_hypers; } - -// //! Adds a datum and its index to the hierarchy -// void add_datum( -// const int id, const Eigen::RowVectorXd &datum, -// const bool update_params = false, -// const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; - -// //! Removes a datum and its index from the hierarchy -// void remove_datum( -// const int id, const Eigen::RowVectorXd &datum, -// const bool update_params = false, -// const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; - -// //! Main function that initializes members to appropriate values -// void initialize() override { -// hypers = std::make_shared(); -// check_prior_is_set(); -// initialize_hypers(); -// initialize_state(); -// posterior_hypers = *hypers; -// clear_data(); -// clear_summary_statistics(); -// } - -// protected: -// //! Raises an error if the prior pointer is not initialized -// void check_prior_is_set() const { -// if (prior == nullptr) { -// throw std::invalid_argument("Hierarchy prior was not provided"); -// } -// } - -// //! Re-initializes the prior of the hierarchy to a newly created object -// void create_empty_prior() { prior.reset(new Prior); } - -// //! Sets the cardinality of the cluster -// void set_card(const int card_) { -// card = card_; -// log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); -// } - -// //! Writes current state to a Protobuf message and return a shared_ptr -// //! New hierarchies have to first modify the field 'oneof val' in the -// //! AlgoritmState::ClusterState message by adding the appropriate type -// virtual std::shared_ptr -// get_state_proto() const = 0; - -// //! Initializes state parameters to appropriate values -// virtual void initialize_state() = 0; - -// //! Writes current value of hyperparameters to a Protobuf message and -// //! return a shared_ptr. -// //! New hierarchies have to first modify the field 'oneof val' in the -// //! AlgoritmState::HierarchyHypers message by adding the appropriate type -// virtual std::shared_ptr -// get_hypers_proto() const = 0; - -// //! Initializes hierarchy hyperparameters to appropriate values -// virtual void initialize_hypers() = 0; - -// //! Resets cardinality and indexes of data in this cluster -// void clear_data() { -// set_card(0); -// cluster_data_idx = std::set(); -// } - -// virtual void clear_summary_statistics() = 0; - -// //! Down-casts the given generic proto message to a ClusterState proto -// bayesmix::AlgorithmState::ClusterState *downcast_state( -// google::protobuf::Message *state_) const { -// return google::protobuf::internal::down_cast< -// bayesmix::AlgorithmState::ClusterState *>(state_); -// } - -// //! Down-casts the given generic proto message to a ClusterState proto -// const bayesmix::AlgorithmState::ClusterState &f( -// const google::protobuf::Message &state_) const { -// return google::protobuf::internal::down_cast< -// const bayesmix::AlgorithmState::ClusterState &>(state_); -// } - -// //! Down-casts the given generic proto message to a HierarchyHypers proto -// bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( -// google::protobuf::Message *state_) const { -// return google::protobuf::internal::down_cast< -// bayesmix::AlgorithmState::HierarchyHypers *>(state_); -// } - -// //! Down-casts the given generic proto message to a HierarchyHypers proto -// const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( -// const google::protobuf::Message &state_) const { -// return google::protobuf::internal::down_cast< -// const bayesmix::AlgorithmState::HierarchyHypers &>(state_); -// } - -// //! Container for state values -// State state; - -// //! Container for prior hyperparameters values -// std::shared_ptr hypers; - -// //! Container for posterior hyperparameters values -// Hyperparams posterior_hypers; - -// //! Pointer to a Protobuf prior object for this class -// std::shared_ptr prior; - -// //! Set of indexes of data points belonging to this cluster -// std::set cluster_data_idx; - -// //! Current cardinality of this cluster -// int card = 0; - -// //! Logarithm of current cardinality of this cluster -// double log_card = stan::math::NEGATIVE_INFTY; -// }; - -// template void BaseHierarchy::add_datum( -// const int id, const Eigen::RowVectorXd &datum, -// const bool update_params /*= false*/, -// const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) { -// assert(cluster_data_idx.find(id) == cluster_data_idx.end()); -// card += 1; -// log_card = std::log(card); -// static_cast(this)->update_ss(datum, covariate, true); -// cluster_data_idx.insert(id); -// if (update_params) { -// static_cast(this)->save_posterior_hypers(); -// } -// } - -// template void BaseHierarchy::remove_datum( -// const int id, const Eigen::RowVectorXd &datum, -// const bool update_params /*= false*/, -// const Eigen::RowVectorXd &covariate /* = Eigen::RowVectorXd(0)*/) { -// static_cast(this)->update_ss(datum, covariate, false); -// set_card(card - 1); -// auto it = cluster_data_idx.find(id); -// assert(it != cluster_data_idx.end()); -// cluster_data_idx.erase(it); -// if (update_params) { -// static_cast(this)->save_posterior_hypers(); -// } -// } - -// template void BaseHierarchy::write_state_to_proto( -// google::protobuf::Message *out) const { -// std::shared_ptr state_ = -// get_state_proto(); -// auto *out_cast = downcast_state(out); -// out_cast->CopyFrom(*state_.get()); -// out_cast->set_cardinality(card); -// } - -// template void BaseHierarchy::write_hypers_to_proto( -// google::protobuf::Message *out) const { -// std::shared_ptr hypers_ = -// get_hypers_proto(); -// auto *out_cast = downcast_hypers(out); -// out_cast->CopyFrom(*hypers_.get()); -// } - -// template Eigen::VectorXd BaseHierarchy::like_lpdf_grid( -// const Eigen::MatrixXd &data, -// const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { -// Eigen::VectorXd lpdf(data.rows()); -// if (covariates.cols() == 0) { -// // Pass null value as covariate -// for (int i = 0; i < data.rows(); i++) { -// lpdf(i) = static_cast(this)->get_like_lpdf( -// data.row(i), Eigen::RowVectorXd(0)); -// } -// } else if (covariates.rows() == 1) { -// // Use unique covariate -// for (int i = 0; i < data.rows(); i++) { -// lpdf(i) = static_cast(this)->get_like_lpdf( -// data.row(i), covariates.row(0)); -// } -// } else { -// // Use different covariates -// for (int i = 0; i < data.rows(); i++) { -// lpdf(i) = static_cast(this)->get_like_lpdf( -// data.row(i), covariates.row(i)); -// } -// } -// return lpdf; -// } - -// template void BaseHierarchy::sample_full_cond( -// const Eigen::MatrixXd &data, -// const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) { -// clear_data(); -// clear_summary_statistics(); -// if (covariates.cols() == 0) { -// // Pass null value as covariate -// for (int i = 0; i < data.rows(); i++) { -// static_cast(this)->add_datum(i, data.row(i), false, -// Eigen::RowVectorXd(0)); -// } -// } else if (covariates.rows() == 1) { -// // Use unique covariate -// for (int i = 0; i < data.rows(); i++) { -// static_cast(this)->add_datum(i, data.row(i), false, -// covariates.row(0)); -// } -// } else { -// // Use different covariates -// for (int i = 0; i < data.rows(); i++) { -// static_cast(this)->add_datum(i, data.row(i), false, -// covariates.row(i)); -// } -// } -// static_cast(this)->sample_full_cond(true); -// } +template +class target_lpdf_unconstrained { + protected: + const DerivedHierarchy &parent; + + public: + target_lpdf_unconstrained(const DerivedHierarchy &p): parent(p) {} + + template + T operator()(const Eigen::Matrix &x) const { + return parent.like->clus_lpdf_from_unconstrained(x) + + parent.prior->lpdf_from_unconstrained(x); + } +}; + +template +void BaseHierarchy::sample_full_cond( + bool update_params) { + target_lpdf_unconstrained target(static_cast(*this)); + updater->draw(*like, *prior, update_params, target); +}; + + // template void BaseHierarchy::write_state_to_proto( + // google::protobuf::Message *out) const { + // std::shared_ptr state_ = + // get_state_proto(); + // auto *out_cast = downcast_state(out); + // out_cast->CopyFrom(*state_.get()); + // out_cast->set_cardinality(card); + // } + + // template void BaseHierarchy::write_hypers_to_proto( + // google::protobuf::Message *out) const { + // std::shared_ptr hypers_ = + // get_hypers_proto(); + // auto *out_cast = downcast_hypers(out); + // out_cast->CopyFrom(*hypers_.get()); + // } + + // template Eigen::VectorXd BaseHierarchy::like_lpdf_grid( + // const Eigen::MatrixXd &data, + // const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { + // Eigen::VectorXd lpdf(data.rows()); + // if (covariates.cols() == 0) { + // // Pass null value as covariate + // for (int i = 0; i < data.rows(); i++) { + // lpdf(i) = static_cast(this)->get_like_lpdf( + // data.row(i), Eigen::RowVectorXd(0)); + // } + // } else if (covariates.rows() == 1) { + // // Use unique covariate + // for (int i = 0; i < data.rows(); i++) { + // lpdf(i) = static_cast(this)->get_like_lpdf( + // data.row(i), covariates.row(0)); + // } + // } else { + // // Use different covariates + // for (int i = 0; i < data.rows(); i++) { + // lpdf(i) = static_cast(this)->get_like_lpdf( + // data.row(i), covariates.row(i)); + // } + // } + // return lpdf; + // } + + // template void BaseHierarchy::sample_full_cond( + // const Eigen::MatrixXd &data, + // const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) { + // clear_data(); + // clear_summary_statistics(); + // if (covariates.cols() == 0) { + // // Pass null value as covariate + // for (int i = 0; i < data.rows(); i++) { + // static_cast(this)->add_datum(i, data.row(i), false, + // Eigen::RowVectorXd(0)); + // } + // } else if (covariates.rows() == 1) { + // // Use unique covariate + // for (int i = 0; i < data.rows(); i++) { + // static_cast(this)->add_datum(i, data.row(i), false, + // covariates.row(0)); + // } + // } else { + // // Use different covariates + // for (int i = 0; i < data.rows(); i++) { + // static_cast(this)->add_datum(i, data.row(i), false, + // covariates.row(i)); + // } + // } + // static_cast(this)->sample_full_cond(true); + // } #endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 0145d04fd..e5c6e283d 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -8,8 +8,8 @@ class AbstractUpdater { public: virtual ~AbstractUpdater() = default; virtual bool is_conjugate() const { return false; }; - virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, - bool update_params) = 0; +// virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, +// bool update_params) = 0; virtual void compute_posterior_hypers(AbstractLikelihood &like, AbstractPriorModel &prior) { throw std::runtime_error("compute_posterior_hypers not implemented"); diff --git a/src/hierarchies/updaters/metropolis_updater.h b/src/hierarchies/updaters/metropolis_updater.h index b328cf7d4..4de5a3902 100644 --- a/src/hierarchies/updaters/metropolis_updater.h +++ b/src/hierarchies/updaters/metropolis_updater.h @@ -3,32 +3,42 @@ #include "abstract_updater.h" +template class MetropolisUpdater : public AbstractUpdater { public: + + template void draw(AbstractLikelihood &like, AbstractPriorModel &prior, - bool update_params) override { + bool update_params, F& target_lpdf) { Eigen::VectorXd curr_state = like.get_unconstrained_state(); - Eigen::VectorXd prop_state = sample_proposal(curr_state, like, prior); + Eigen::VectorXd prop_state = static_cast(this)->sample_proposal( + curr_state, like, prior, target_lpdf); double log_arate = like.cluster_lpdf_from_unconstrained(prop_state) - like.cluster_lpdf_from_unconstrained(curr_state) + - proposal_lpdf(curr_state, prop_state, like, prior) - - proposal_lpdf(prop_state, curr_state, like, prior); + static_cast(this)->proposal_lpdf( + curr_state, prop_state, like, prior, target_lpdf) - + static_cast(this)->proposal_lpdf( + prop_state, curr_state, like, prior, target_lpdf); auto &rng = bayesmix::Rng::Instance().get(); if (std::log(stan::math::uniform_rng(0, 1, rng)) < log_arate) { like.set_state_from_unconstrained(prop_state); } } - - virtual Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, - AbstractLikelihood &like, - AbstractPriorModel &prior) = 0; - - virtual double proposal_lpdf(Eigen::VectorXd prop_state, - Eigen::VectorXd curr_state, - AbstractLikelihood &like, - AbstractPriorModel &prior) = 0; + +// template +// virtual Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, +// AbstractLikelihood &like, +// AbstractPriorModel &prior, +// F& target_lpdf) = 0; + +// template +// virtual double proposal_lpdf(Eigen::VectorXd prop_state, +// Eigen::VectorXd curr_state, +// AbstractLikelihood &like, +// AbstractPriorModel &prior, +// F& target_lpdf) = 0; }; #endif diff --git a/src/hierarchies/updaters/random_walk_updater.h b/src/hierarchies/updaters/random_walk_updater.h index 0f39edfd9..0f74bb247 100644 --- a/src/hierarchies/updaters/random_walk_updater.h +++ b/src/hierarchies/updaters/random_walk_updater.h @@ -3,7 +3,7 @@ #include "metropolis_updater.h" -class RandomWalkUpdater : public MetropolisUpdater { +class RandomWalkUpdater : public MetropolisUpdater { protected: double step_size; @@ -13,9 +13,10 @@ class RandomWalkUpdater : public MetropolisUpdater { RandomWalkUpdater(double step_size) : step_size(step_size) {} + template Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, AbstractLikelihood &like, - AbstractPriorModel &prior) override { + AbstractPriorModel &prior, F& target_lpdf) { Eigen::VectorXd step(curr_state.size()); auto &rng = bayesmix::Rng::Instance().get(); for (int i = 0; i < curr_state.size(); i++) { @@ -23,10 +24,11 @@ class RandomWalkUpdater : public MetropolisUpdater { } return curr_state + step; } - + + template double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, AbstractLikelihood &like, - AbstractPriorModel &prior) override { + AbstractPriorModel &prior, F& target_lpdf) { double out; for (int i = 0; i < prop_state.size(); i++) { out += stan::math::normal_lpdf(prop_state(i), curr_state(i), step_size); From a1b0619035ad9f7f4652a4e5e6394089ea8f6efd Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 27 Jan 2022 16:24:25 +0100 Subject: [PATCH 109/317] mala updater working --- src/hierarchies/updaters/abstract_updater.h | 4 +- src/hierarchies/updaters/mala_updater.h | 63 +++++++++++++++++++ src/hierarchies/updaters/metropolis_updater.h | 38 +++++------ .../updaters/random_walk_updater.h | 10 +-- 4 files changed, 89 insertions(+), 26 deletions(-) create mode 100644 src/hierarchies/updaters/mala_updater.h diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index e5c6e283d..9a3a78d57 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -8,8 +8,8 @@ class AbstractUpdater { public: virtual ~AbstractUpdater() = default; virtual bool is_conjugate() const { return false; }; -// virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, -// bool update_params) = 0; + // virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, + // bool update_params) = 0; virtual void compute_posterior_hypers(AbstractLikelihood &like, AbstractPriorModel &prior) { throw std::runtime_error("compute_posterior_hypers not implemented"); diff --git a/src/hierarchies/updaters/mala_updater.h b/src/hierarchies/updaters/mala_updater.h new file mode 100644 index 000000000..f407f3b78 --- /dev/null +++ b/src/hierarchies/updaters/mala_updater.h @@ -0,0 +1,63 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_MALA_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_MALA_UPDATER_H_ + +#include "metropolis_updater.h" + +#include + +class MalaUpdater : public MetropolisUpdater { + protected: + double step_size; + + public: + MalaUpdater() = default; + ~MalaUpdater() = default; + + MalaUpdater(double step_size) : step_size(step_size) {} + + template + Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, + AbstractLikelihood &like, + AbstractPriorModel &prior, + F& target_lpdf) { + + Eigen::VectorXd noise(curr_state.size()); + auto &rng = bayesmix::Rng::Instance().get(); + double noise_scale = std::sqrt(2 * step_size); + for (int i = 0; i < curr_state.size(); i++) { + noise(i) = stan::math::normal_rng(0, noise_scale, rng); + } + Eigen::VectorXd grad; + double tmp; + stan::math::gradient(target_lpdf, curr_state, tmp, grad); + return curr_state + step_size * grad + noise; + } + + template + double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, + AbstractLikelihood &like, + AbstractPriorModel &prior, + F& target_lpdf) { + double out; + Eigen::VectorXd grad; + double tmp; + stan::math::gradient(target_lpdf, curr_state, tmp, grad); + Eigen::VectorXd mean = curr_state + step_size * grad; + + double noise_scale = std::sqrt(2 * step_size); + + for (int i = 0; i < prop_state.size(); i++) { + out += stan::math::normal_lpdf(prop_state(i), mean(i), noise_scale); + } + return out; + } + + std::shared_ptr clone() const { + auto out = std::make_shared( + static_cast(*this)); + return out; + } + +}; + +#endif diff --git a/src/hierarchies/updaters/metropolis_updater.h b/src/hierarchies/updaters/metropolis_updater.h index 4de5a3902..70f91b9f4 100644 --- a/src/hierarchies/updaters/metropolis_updater.h +++ b/src/hierarchies/updaters/metropolis_updater.h @@ -6,19 +6,19 @@ template class MetropolisUpdater : public AbstractUpdater { public: - template void draw(AbstractLikelihood &like, AbstractPriorModel &prior, - bool update_params, F& target_lpdf) { + bool update_params, F &target_lpdf) { Eigen::VectorXd curr_state = like.get_unconstrained_state(); - Eigen::VectorXd prop_state = static_cast(this)->sample_proposal( - curr_state, like, prior, target_lpdf); + Eigen::VectorXd prop_state = + static_cast(this)->sample_proposal( + curr_state, like, prior, target_lpdf); double log_arate = like.cluster_lpdf_from_unconstrained(prop_state) - like.cluster_lpdf_from_unconstrained(curr_state) + - static_cast(this)->proposal_lpdf( + static_cast(this)->proposal_lpdf( curr_state, prop_state, like, prior, target_lpdf) - - static_cast(this)->proposal_lpdf( + static_cast(this)->proposal_lpdf( prop_state, curr_state, like, prior, target_lpdf); auto &rng = bayesmix::Rng::Instance().get(); @@ -26,19 +26,19 @@ class MetropolisUpdater : public AbstractUpdater { like.set_state_from_unconstrained(prop_state); } } - -// template -// virtual Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, -// AbstractLikelihood &like, -// AbstractPriorModel &prior, -// F& target_lpdf) = 0; - -// template -// virtual double proposal_lpdf(Eigen::VectorXd prop_state, -// Eigen::VectorXd curr_state, -// AbstractLikelihood &like, -// AbstractPriorModel &prior, -// F& target_lpdf) = 0; + + // template + // virtual Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, + // AbstractLikelihood &like, + // AbstractPriorModel &prior, + // F& target_lpdf) = 0; + + // template + // virtual double proposal_lpdf(Eigen::VectorXd prop_state, + // Eigen::VectorXd curr_state, + // AbstractLikelihood &like, + // AbstractPriorModel &prior, + // F& target_lpdf) = 0; }; #endif diff --git a/src/hierarchies/updaters/random_walk_updater.h b/src/hierarchies/updaters/random_walk_updater.h index 0f74bb247..b9559a743 100644 --- a/src/hierarchies/updaters/random_walk_updater.h +++ b/src/hierarchies/updaters/random_walk_updater.h @@ -13,10 +13,10 @@ class RandomWalkUpdater : public MetropolisUpdater { RandomWalkUpdater(double step_size) : step_size(step_size) {} - template + template Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, AbstractLikelihood &like, - AbstractPriorModel &prior, F& target_lpdf) { + AbstractPriorModel &prior, F &target_lpdf) { Eigen::VectorXd step(curr_state.size()); auto &rng = bayesmix::Rng::Instance().get(); for (int i = 0; i < curr_state.size(); i++) { @@ -24,11 +24,11 @@ class RandomWalkUpdater : public MetropolisUpdater { } return curr_state + step; } - + template double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, - AbstractLikelihood &like, - AbstractPriorModel &prior, F& target_lpdf) { + AbstractLikelihood &like, AbstractPriorModel &prior, + F &target_lpdf) { double out; for (int i = 0; i < prop_state.size(); i++) { out += stan::math::normal_lpdf(prop_state(i), curr_state(i), step_size); From 41745ed857e754f8b7d18d9bcbffdcd8f10bfd4a Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 27 Jan 2022 16:24:48 +0100 Subject: [PATCH 110/317] ad working --- benchmarks/lpd_grid.cc | 5 +-- benchmarks/nnw_marg_lpdf.cc | 2 +- examples/gamma_hierarchy/gamma_gamma_hier.h | 7 ++-- python/notebooks/gaussian_mix_uni.ipynb | 4 +- run_mcmc.cc | 2 +- src/algorithms/base_algorithm.cc | 2 +- src/algorithms/base_algorithm.h | 2 +- src/algorithms/conditional_algorithm.cc | 2 +- src/algorithms/conditional_algorithm.h | 2 +- src/algorithms/marginal_algorithm.cc | 2 +- src/algorithms/marginal_algorithm.h | 2 +- src/algorithms/neal2_algorithm.cc | 2 +- src/algorithms/neal2_algorithm.h | 2 +- src/algorithms/neal3_algorithm.cc | 2 +- src/algorithms/neal8_algorithm.cc | 2 +- src/algorithms/neal8_algorithm.h | 2 +- src/algorithms/semihdp_sampler.h | 2 +- src/hierarchies/abstract_hierarchy.h | 2 +- src/hierarchies/base_hierarchy.h | 9 ++-- .../likelihoods/abstract_likelihood.h | 2 +- src/hierarchies/likelihoods/base_likelihood.h | 2 +- src/hierarchies/likelihoods/states.h | 42 +++++++++++++++---- .../likelihoods/uni_norm_likelihood.h | 2 +- src/hierarchies/nnig_hierarchy.h | 7 ++-- src/hierarchies/priors/abstract_prior_model.h | 2 +- src/hierarchies/priors/base_prior_model.h | 2 +- src/hierarchies/priors/hyperparams.h | 2 +- src/hierarchies/priors/nig_prior_model.cc | 7 ---- src/hierarchies/priors/nig_prior_model.h | 21 +++++++++- src/hierarchies/updaters/mala_updater.h | 22 ++++------ src/mixings/abstract_mixing.h | 2 +- src/mixings/base_mixing.h | 2 +- src/mixings/dirichlet_mixing.h | 2 +- src/mixings/logit_sb_mixing.cc | 2 +- src/mixings/logit_sb_mixing.h | 2 +- src/mixings/pityor_mixing.cc | 2 +- src/mixings/pityor_mixing.h | 2 +- src/mixings/truncated_sb_mixing.cc | 2 +- src/mixings/truncated_sb_mixing.h | 2 +- src/utils/cluster_utils.cc | 2 +- src/utils/cluster_utils.h | 2 +- src/utils/distributions.cc | 2 +- src/utils/distributions.h | 2 +- src/utils/eigen_utils.h | 2 +- src/utils/io_utils.cc | 2 +- src/utils/io_utils.h | 2 +- src/utils/proto_utils.cc | 2 +- src/utils/proto_utils.h | 2 +- test/collectors.cc | 2 +- test/distributions.cc | 2 +- test/eigen_utils.cc | 2 +- test/hierarchies.cc | 2 +- test/likelihoods.cc | 2 +- test/logit_sb.cc | 2 +- test/lpdf.cc | 2 +- test/mfm_mixing.cc | 4 +- test/prior_models.cc | 2 +- test/proto_utils.cc | 2 +- test/rng.cc | 2 +- test/semi_hdp.cc | 2 +- 60 files changed, 129 insertions(+), 99 deletions(-) diff --git a/benchmarks/lpd_grid.cc b/benchmarks/lpd_grid.cc index d297f5943..507b20d0b 100644 --- a/benchmarks/lpd_grid.cc +++ b/benchmarks/lpd_grid.cc @@ -1,7 +1,7 @@ #include -#include #include +#include #include "src/utils/distributions.h" #include "utils.h" @@ -37,7 +37,6 @@ Eigen::VectorXd lpdf_naive(const Eigen::MatrixXd &x, return out; } - Eigen::VectorXd lpdf_fully_optimized(const Eigen::MatrixXd &x, const Eigen::VectorXd &mean, const Eigen::MatrixXd &prec_chol, @@ -84,7 +83,6 @@ static void BM_gauss_lpdf_naive(benchmark::State &state) { } } - static void BM_gauss_lpdf_fully_optimized(benchmark::State &state) { int dim = state.range(0); Eigen::VectorXd mean = Eigen::VectorXd::Zero(dim); @@ -99,7 +97,6 @@ static void BM_gauss_lpdf_fully_optimized(benchmark::State &state) { } } - BENCHMARK(BM_gauss_lpdf_cov)->RangeMultiplier(2)->Range(2, 2 << 4); BENCHMARK(BM_gauss_lpdf_cov)->RangeMultiplier(2)->Range(2, 2 << 4); BENCHMARK(BM_gauss_lpdf_naive)->RangeMultiplier(2)->Range(2, 2 << 4); diff --git a/benchmarks/nnw_marg_lpdf.cc b/benchmarks/nnw_marg_lpdf.cc index 696a8a198..a99314692 100644 --- a/benchmarks/nnw_marg_lpdf.cc +++ b/benchmarks/nnw_marg_lpdf.cc @@ -1,7 +1,7 @@ #include -#include #include +#include #include "utils.h" diff --git a/examples/gamma_hierarchy/gamma_gamma_hier.h b/examples/gamma_hierarchy/gamma_gamma_hier.h index 2e31a359b..0b80d7867 100644 --- a/examples/gamma_hierarchy/gamma_gamma_hier.h +++ b/examples/gamma_hierarchy/gamma_gamma_hier.h @@ -4,12 +4,13 @@ #include #include #include -#include "hierarchy_prior.pb.h" -#include #include +#include #include +#include "hierarchy_prior.pb.h" + namespace GammaGamma { //! Custom container for State values struct State { @@ -97,7 +98,7 @@ class GammaGammaHierarchy set_card(statecast.cardinality()); } - std::shared_ptr get_state_proto() + std::shared_ptr get_state_proto() const override { bayesmix::Vector state_; state_.mutable_data()->Add(state.rate); diff --git a/python/notebooks/gaussian_mix_uni.ipynb b/python/notebooks/gaussian_mix_uni.ipynb index 89e6ea598..2abba91ec 100644 --- a/python/notebooks/gaussian_mix_uni.ipynb +++ b/python/notebooks/gaussian_mix_uni.ipynb @@ -256,7 +256,9 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "ciao" + ] }, { "cell_type": "markdown", diff --git a/run_mcmc.cc b/run_mcmc.cc index 93f00d637..6ad72ae43 100644 --- a/run_mcmc.cc +++ b/run_mcmc.cc @@ -180,7 +180,7 @@ int main(int argc, char *argv[]) { std::cout << "hier->prior: \n" << hier->get_mutable_prior()->DebugString() << std::endl; - auto updater = std::make_shared(0.25); + auto updater = std::make_shared(0.25); hier->set_updater(updater); hier->initialize(); diff --git a/src/algorithms/base_algorithm.cc b/src/algorithms/base_algorithm.cc index 41793f54d..ef8bba064 100644 --- a/src/algorithms/base_algorithm.cc +++ b/src/algorithms/base_algorithm.cc @@ -1,7 +1,7 @@ #include "base_algorithm.h" -#include #include +#include #include #include "algorithm_params.pb.h" diff --git a/src/algorithms/base_algorithm.h b/src/algorithms/base_algorithm.h index 060e60c82..5aab07a49 100644 --- a/src/algorithms/base_algorithm.h +++ b/src/algorithms/base_algorithm.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "algorithm_id.pb.h" diff --git a/src/algorithms/conditional_algorithm.cc b/src/algorithms/conditional_algorithm.cc index 30090013a..aebc2cdf5 100644 --- a/src/algorithms/conditional_algorithm.cc +++ b/src/algorithms/conditional_algorithm.cc @@ -1,7 +1,7 @@ #include "conditional_algorithm.h" -#include #include +#include #include "algorithm_state.pb.h" #include "base_algorithm.h" diff --git a/src/algorithms/conditional_algorithm.h b/src/algorithms/conditional_algorithm.h index 3b09585f5..9a3c408dd 100644 --- a/src/algorithms/conditional_algorithm.h +++ b/src/algorithms/conditional_algorithm.h @@ -1,8 +1,8 @@ #ifndef BAYESMIX_ALGORITHMS_CONDITIONAL_ALGORITHM_H_ #define BAYESMIX_ALGORITHMS_CONDITIONAL_ALGORITHM_H_ -#include #include +#include #include "base_algorithm.h" #include "src/collectors/base_collector.h" diff --git a/src/algorithms/marginal_algorithm.cc b/src/algorithms/marginal_algorithm.cc index 130a6be61..48d2a167c 100644 --- a/src/algorithms/marginal_algorithm.cc +++ b/src/algorithms/marginal_algorithm.cc @@ -1,8 +1,8 @@ #include "marginal_algorithm.h" -#include #include #include +#include #include "algorithm_state.pb.h" #include "base_algorithm.h" diff --git a/src/algorithms/marginal_algorithm.h b/src/algorithms/marginal_algorithm.h index 1691de143..9535b705d 100644 --- a/src/algorithms/marginal_algorithm.h +++ b/src/algorithms/marginal_algorithm.h @@ -1,8 +1,8 @@ #ifndef BAYESMIX_ALGORITHMS_MARGINAL_ALGORITHM_H_ #define BAYESMIX_ALGORITHMS_MARGINAL_ALGORITHM_H_ -#include #include +#include #include "base_algorithm.h" #include "src/collectors/base_collector.h" diff --git a/src/algorithms/neal2_algorithm.cc b/src/algorithms/neal2_algorithm.cc index 70a73b66c..ecd0c25f5 100644 --- a/src/algorithms/neal2_algorithm.cc +++ b/src/algorithms/neal2_algorithm.cc @@ -1,8 +1,8 @@ #include "neal2_algorithm.h" -#include #include #include +#include #include #include "algorithm_id.pb.h" diff --git a/src/algorithms/neal2_algorithm.h b/src/algorithms/neal2_algorithm.h index 7a73d9f91..1eb7f1dbc 100644 --- a/src/algorithms/neal2_algorithm.h +++ b/src/algorithms/neal2_algorithm.h @@ -1,8 +1,8 @@ #ifndef BAYESMIX_ALGORITHMS_NEAL2_ALGORITHM_H_ #define BAYESMIX_ALGORITHMS_NEAL2_ALGORITHM_H_ -#include #include +#include #include "algorithm_id.pb.h" #include "marginal_algorithm.h" diff --git a/src/algorithms/neal3_algorithm.cc b/src/algorithms/neal3_algorithm.cc index 7e209b81f..e7bf9712a 100644 --- a/src/algorithms/neal3_algorithm.cc +++ b/src/algorithms/neal3_algorithm.cc @@ -1,7 +1,7 @@ #include "neal3_algorithm.h" -#include #include +#include #include "hierarchy_id.pb.h" #include "mixing_id.pb.h" diff --git a/src/algorithms/neal8_algorithm.cc b/src/algorithms/neal8_algorithm.cc index 808fa7deb..457e6da3a 100644 --- a/src/algorithms/neal8_algorithm.cc +++ b/src/algorithms/neal8_algorithm.cc @@ -1,8 +1,8 @@ #include "neal8_algorithm.h" -#include #include #include +#include #include "algorithm_id.pb.h" #include "algorithm_state.pb.h" diff --git a/src/algorithms/neal8_algorithm.h b/src/algorithms/neal8_algorithm.h index 11ad0048f..30edc06ea 100644 --- a/src/algorithms/neal8_algorithm.h +++ b/src/algorithms/neal8_algorithm.h @@ -1,8 +1,8 @@ #ifndef BAYESMIX_ALGORITHMS_NEAL8_ALGORITHM_H_ #define BAYESMIX_ALGORITHMS_NEAL8_ALGORITHM_H_ -#include #include +#include #include #include "algorithm_id.pb.h" diff --git a/src/algorithms/semihdp_sampler.h b/src/algorithms/semihdp_sampler.h index e684c6dd1..3558a3ae9 100644 --- a/src/algorithms/semihdp_sampler.h +++ b/src/algorithms/semihdp_sampler.h @@ -1,9 +1,9 @@ #ifndef SRC_ALGORITHMS_SEMIHDP_SAMPLER_H #define SRC_ALGORITHMS_SEMIHDP_SAMPLER_H -#include #include #include +#include #include #include diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 8eee2d403..963ac12b8 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -3,11 +3,11 @@ #include -#include #include #include #include #include +#include #include "algorithm_state.pb.h" #include "hierarchy_id.pb.h" diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 9adca2835..e7fe80785 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -35,7 +35,6 @@ class target_lpdf_unconstrained; template class BaseHierarchy : public AbstractHierarchy { - template friend class target_lpdf_unconstrained; @@ -298,11 +297,11 @@ class target_lpdf_unconstrained { const DerivedHierarchy &parent; public: - target_lpdf_unconstrained(const DerivedHierarchy &p): parent(p) {} + target_lpdf_unconstrained(const DerivedHierarchy &p) : parent(p) {} - template + template T operator()(const Eigen::Matrix &x) const { - return parent.like->clus_lpdf_from_unconstrained(x) + + return parent.like->cluster_lpdf_from_unconstrained(x) + parent.prior->lpdf_from_unconstrained(x); } }; @@ -310,7 +309,7 @@ class target_lpdf_unconstrained { template void BaseHierarchy::sample_full_cond( bool update_params) { - target_lpdf_unconstrained target(static_cast(*this)); + target_lpdf_unconstrained target(static_cast(*this)); updater->draw(*like, *prior, update_params, target); }; diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 413937dd5..987461cbe 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -3,8 +3,8 @@ #include -#include #include +#include // #include // #include diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 3297114dd..2e294b330 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -3,8 +3,8 @@ #include -#include #include +#include // #include #include // #include diff --git a/src/hierarchies/likelihoods/states.h b/src/hierarchies/likelihoods/states.h index 530c90123..52e2d30b5 100644 --- a/src/hierarchies/likelihoods/states.h +++ b/src/hierarchies/likelihoods/states.h @@ -1,27 +1,51 @@ #ifndef BAYESMIX_HIERARCHIES_LIKELIHOOD_STATES_H_ #define BAYESMIX_HIERARCHIES_LIKELIHOOD_STATES_H_ -#include #include +#include #include "algorithm_state.pb.h" #include "src/utils/proto_utils.h" namespace State { +template +Eigen::Matrix uni_ls_to_constrained( + Eigen::Matrix in) { + Eigen::Matrix out(2); + out << in(0), stan::math::exp(in(1)); + return out; +} + +template +Eigen::Matrix uni_ls_to_unconstrained( + Eigen::Matrix in) { + Eigen::Matrix out(2); + out << in(0), stan::math::log(in(1)); + return out; +} + +template +T uni_ls_log_det_jac(Eigen::Matrix constrained) { + T out = 0; + stan::math::positive_constrain(stan::math::log(constrained(1)), out); + return out; +} + class UniLS { public: double mean, var; Eigen::VectorXd get_unconstrained() { - Eigen::VectorXd out(2); - out << mean, std::log(var); - return out; + Eigen::VectorXd temp(2); + temp << mean, var; + return uni_ls_to_unconstrained(temp); } void set_from_unconstrained(Eigen::VectorXd in) { - mean = in(0); - var = std::exp(in(1)); + Eigen::VectorXd temp = uni_ls_to_constrained(in); + mean = temp(0); + var = temp(1); } void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { @@ -37,9 +61,9 @@ class UniLS { } double log_det_jac() { - double out = 0; - stan::math::positive_constrain(std::log(var), out); - return out; + Eigen::VectorXd temp(2); + temp << mean, var; + return uni_ls_log_det_jac(temp); } }; diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index bc3108283..407083f00 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -32,7 +32,7 @@ class UniNormLikelihood T var = stan::math::positive_constrain(unconstrained_params(1)); T out = -(data_sum_squares - 2 * mean * data_sum + card * mean * mean) / (2 * var); - out -= card * 0.5 * std::log(stan::math::TWO_PI * var); + out -= card * 0.5 * stan::math::log(stan::math::TWO_PI * var); return out; } diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index bb2e7606e..645af2c96 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -3,7 +3,7 @@ // #include -// #include +// #include // #include // #include @@ -16,15 +16,16 @@ #include "likelihoods/uni_norm_likelihood.h" #include "priors/nig_prior_model.h" // #include "updaters/nnig_updater.h" +#include "updaters/mala_updater.h" #include "updaters/random_walk_updater.h" class NNIGHierarchy : public BaseHierarchy { + NIGPriorModel, MalaUpdater> { public: ~NNIGHierarchy() = default; using BaseHierarchy::BaseHierarchy; + MalaUpdater>::BaseHierarchy; bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNIG; diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 2852a748e..cbf7ba7f6 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -26,7 +26,7 @@ class AbstractPriorModel { //! Usually, some kind of transformation is required from the unconstrained //! parameterization to the actual parameterization. virtual double lpdf_from_unconstrained( - Eigen::VectorXd unconstrained_params) { + Eigen::VectorXd unconstrained_params) const { throw std::runtime_error("lpdf_from_unconstrained() not yet implemented"); } diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index afc21a0d7..924584aed 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -3,11 +3,11 @@ #include -#include #include #include #include #include +#include #include "abstract_prior_model.h" #include "algorithm_state.pb.h" diff --git a/src/hierarchies/priors/hyperparams.h b/src/hierarchies/priors/hyperparams.h index b338aa1fe..167acce2b 100644 --- a/src/hierarchies/priors/hyperparams.h +++ b/src/hierarchies/priors/hyperparams.h @@ -1,7 +1,7 @@ #ifndef BAYESMIX_HIERARCHIES_PRIORMODEL_HYPERPARAMS_H_ #define BAYESMIX_HIERARCHIES_PRIORMODEL_HYPERPARAMS_H_ -#include +#include namespace Hyperparams { diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index fd6476fc2..a92af1b7b 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -84,13 +84,6 @@ double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { return target; } -double NIGPriorModel::lpdf_from_unconstrained( - Eigen::VectorXd unconstrained_params) { - State::UniLS state; - state.set_from_unconstrained(unconstrained_params); - return lpdf(state.get_as_proto()) + state.log_det_jac(); -} - std::shared_ptr NIGPriorModel::sample( bool use_post_hypers) { auto &rng = bayesmix::Rng::Instance().get(); diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index dcdd327df..e12300835 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -3,9 +3,9 @@ // #include -#include #include #include +#include #include // #include "algorithm_state.pb.h" @@ -22,8 +22,25 @@ class NIGPriorModel : public BasePriorModel + T lpdf_from_unconstrained( + const Eigen::Matrix &unconstrained_params) const { + Eigen::Matrix constrained_params = + State::uni_ls_to_constrained(unconstrained_params); + T log_det_jac = State::uni_ls_log_det_jac(constrained_params); + T mean = constrained_params(0); + T var = constrained_params(1); + T lpdf = stan::math::normal_lpdf(mean, hypers.mean, + sqrt(var / hypers.var_scaling)) + + stan::math::inv_gamma_lpdf(var, hypers.shape, hypers.scale); + + return lpdf + log_det_jac; + } + double lpdf_from_unconstrained( - Eigen::VectorXd unconstrained_params) override; + Eigen::VectorXd unconstrained_params) const override { + return this->lpdf_from_unconstrained(unconstrained_params); + } std::shared_ptr sample( bool use_post_hypers) override; diff --git a/src/hierarchies/updaters/mala_updater.h b/src/hierarchies/updaters/mala_updater.h index f407f3b78..3c0d9d012 100644 --- a/src/hierarchies/updaters/mala_updater.h +++ b/src/hierarchies/updaters/mala_updater.h @@ -1,10 +1,10 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_MALA_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_MALA_UPDATER_H_ -#include "metropolis_updater.h" - #include +#include "metropolis_updater.h" + class MalaUpdater : public MetropolisUpdater { protected: double step_size; @@ -14,13 +14,11 @@ class MalaUpdater : public MetropolisUpdater { ~MalaUpdater() = default; MalaUpdater(double step_size) : step_size(step_size) {} - + template Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, AbstractLikelihood &like, - AbstractPriorModel &prior, - F& target_lpdf) { - + AbstractPriorModel &prior, F &target_lpdf) { Eigen::VectorXd noise(curr_state.size()); auto &rng = bayesmix::Rng::Instance().get(); double noise_scale = std::sqrt(2 * step_size); @@ -32,12 +30,11 @@ class MalaUpdater : public MetropolisUpdater { stan::math::gradient(target_lpdf, curr_state, tmp, grad); return curr_state + step_size * grad + noise; } - + template double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, - AbstractLikelihood &like, - AbstractPriorModel &prior, - F& target_lpdf) { + AbstractLikelihood &like, AbstractPriorModel &prior, + F &target_lpdf) { double out; Eigen::VectorXd grad; double tmp; @@ -53,11 +50,10 @@ class MalaUpdater : public MetropolisUpdater { } std::shared_ptr clone() const { - auto out = std::make_shared( - static_cast(*this)); + auto out = + std::make_shared(static_cast(*this)); return out; } - }; #endif diff --git a/src/mixings/abstract_mixing.h b/src/mixings/abstract_mixing.h index 1e6692caf..0f0d50ba5 100644 --- a/src/mixings/abstract_mixing.h +++ b/src/mixings/abstract_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "mixing_id.pb.h" diff --git a/src/mixings/base_mixing.h b/src/mixings/base_mixing.h index 70d60c69a..e9854fd5f 100644 --- a/src/mixings/base_mixing.h +++ b/src/mixings/base_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include "abstract_mixing.h" diff --git a/src/mixings/dirichlet_mixing.h b/src/mixings/dirichlet_mixing.h index ea5645e73..0fab02538 100644 --- a/src/mixings/dirichlet_mixing.h +++ b/src/mixings/dirichlet_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "base_mixing.h" diff --git a/src/mixings/logit_sb_mixing.cc b/src/mixings/logit_sb_mixing.cc index ed7514f58..a3711086d 100644 --- a/src/mixings/logit_sb_mixing.cc +++ b/src/mixings/logit_sb_mixing.cc @@ -2,10 +2,10 @@ #include -#include #include #include #include +#include #include #include "mixing_prior.pb.h" diff --git a/src/mixings/logit_sb_mixing.h b/src/mixings/logit_sb_mixing.h index 40727cf57..228f4da40 100644 --- a/src/mixings/logit_sb_mixing.h +++ b/src/mixings/logit_sb_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "base_mixing.h" diff --git a/src/mixings/pityor_mixing.cc b/src/mixings/pityor_mixing.cc index a880470b3..59b357981 100644 --- a/src/mixings/pityor_mixing.cc +++ b/src/mixings/pityor_mixing.cc @@ -2,9 +2,9 @@ #include -#include #include #include +#include #include #include "mixing_prior.pb.h" diff --git a/src/mixings/pityor_mixing.h b/src/mixings/pityor_mixing.h index 2ee93ac87..5ba6cd799 100644 --- a/src/mixings/pityor_mixing.h +++ b/src/mixings/pityor_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "base_mixing.h" diff --git a/src/mixings/truncated_sb_mixing.cc b/src/mixings/truncated_sb_mixing.cc index efe73e3d5..b826d5582 100644 --- a/src/mixings/truncated_sb_mixing.cc +++ b/src/mixings/truncated_sb_mixing.cc @@ -2,11 +2,11 @@ #include -#include #include #include #include #include +#include #include #include "logit_sb_mixing.h" diff --git a/src/mixings/truncated_sb_mixing.h b/src/mixings/truncated_sb_mixing.h index 27e16c7fd..120a6cc20 100644 --- a/src/mixings/truncated_sb_mixing.h +++ b/src/mixings/truncated_sb_mixing.h @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include "base_mixing.h" diff --git a/src/utils/cluster_utils.cc b/src/utils/cluster_utils.cc index 6f43200b6..89c65c39b 100644 --- a/src/utils/cluster_utils.cc +++ b/src/utils/cluster_utils.cc @@ -1,7 +1,7 @@ #include "cluster_utils.h" -#include #include +#include #include "lib/progressbar/progressbar.h" #include "proto_utils.h" diff --git a/src/utils/cluster_utils.h b/src/utils/cluster_utils.h index c4466db8c..0fe166399 100644 --- a/src/utils/cluster_utils.h +++ b/src/utils/cluster_utils.h @@ -1,7 +1,7 @@ #ifndef BAYESMIX_UTILS_CLUSTER_UTILS_H_ #define BAYESMIX_UTILS_CLUSTER_UTILS_H_ -#include +#include //! This file includes some utilities for cluster estimation. These functions //! only use Eigen ojects. diff --git a/src/utils/distributions.cc b/src/utils/distributions.cc index 52bcdf665..13efb6e48 100644 --- a/src/utils/distributions.cc +++ b/src/utils/distributions.cc @@ -1,8 +1,8 @@ #include "distributions.h" -#include #include #include +#include #include "src/utils/proto_utils.h" diff --git a/src/utils/distributions.h b/src/utils/distributions.h index 33ae2467e..0af2b0fc3 100644 --- a/src/utils/distributions.h +++ b/src/utils/distributions.h @@ -1,8 +1,8 @@ #ifndef BAYESMIX_UTILS_DISTRIBUTIONS_H_ #define BAYESMIX_UTILS_DISTRIBUTIONS_H_ -#include #include +#include #include #include "algorithm_state.pb.h" diff --git a/src/utils/eigen_utils.h b/src/utils/eigen_utils.h index d2b55460f..a178b3812 100644 --- a/src/utils/eigen_utils.h +++ b/src/utils/eigen_utils.h @@ -1,4 +1,4 @@ -#include +#include #include //! This file implements a few methods to manipulate groups of matrices, mainly diff --git a/src/utils/io_utils.cc b/src/utils/io_utils.cc index 3c4bf59c0..6b22405b3 100644 --- a/src/utils/io_utils.cc +++ b/src/utils/io_utils.cc @@ -1,7 +1,7 @@ #include "io_utils.h" -#include #include +#include Eigen::MatrixXd bayesmix::read_eigen_matrix(const std::string &filename, const char delim /* = ','*/) { diff --git a/src/utils/io_utils.h b/src/utils/io_utils.h index 69182631a..b8c68db3d 100644 --- a/src/utils/io_utils.h +++ b/src/utils/io_utils.h @@ -1,7 +1,7 @@ #ifndef BAYESMIX_UTILS_IO_UTILS_H_ #define BAYESMIX_UTILS_IO_UTILS_H_ -#include +#include //! This file implements basic input-output utilities for Eigen matrices from //! and to text files. diff --git a/src/utils/proto_utils.cc b/src/utils/proto_utils.cc index c1bab2c92..5178ca99f 100644 --- a/src/utils/proto_utils.cc +++ b/src/utils/proto_utils.cc @@ -3,8 +3,8 @@ #include #include -#include #include +#include #include "matrix.pb.h" diff --git a/src/utils/proto_utils.h b/src/utils/proto_utils.h index ca2a53c4b..f662fb2e4 100644 --- a/src/utils/proto_utils.h +++ b/src/utils/proto_utils.h @@ -1,7 +1,7 @@ #ifndef BAYESMIX_UTILS_PROTO_UTILS_H_ #define BAYESMIX_UTILS_PROTO_UTILS_H_ -#include +#include #include "matrix.pb.h" diff --git a/test/collectors.cc b/test/collectors.cc index 38da70caa..7f03e96c5 100644 --- a/test/collectors.cc +++ b/test/collectors.cc @@ -1,6 +1,6 @@ #include -#include +#include #include #include "matrix.pb.h" diff --git a/test/distributions.cc b/test/distributions.cc index acfbe7e31..1e1468b15 100644 --- a/test/distributions.cc +++ b/test/distributions.cc @@ -2,8 +2,8 @@ #include -#include #include +#include #include #include "src/hierarchies/likelihoods/states.h" diff --git a/test/eigen_utils.cc b/test/eigen_utils.cc index 5fd5b23b7..c02f11983 100644 --- a/test/eigen_utils.cc +++ b/test/eigen_utils.cc @@ -2,7 +2,7 @@ #include -#include +#include #include TEST(vstack, 1) { diff --git a/test/hierarchies.cc b/test/hierarchies.cc index a9b98d79b..d0a28081b 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -1,7 +1,7 @@ #include -#include #include +#include #include "algorithm_state.pb.h" #include "ls_state.pb.h" diff --git a/test/likelihoods.cc b/test/likelihoods.cc index 8335d653d..46ecf92db 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -1,8 +1,8 @@ #include -#include #include #include +#include #include "algorithm_state.pb.h" #include "ls_state.pb.h" diff --git a/test/logit_sb.cc b/test/logit_sb.cc index 0033e6bf6..85405dc00 100644 --- a/test/logit_sb.cc +++ b/test/logit_sb.cc @@ -1,7 +1,7 @@ #include -#include #include +#include #include #include "src/hierarchies/abstract_hierarchy.h" diff --git a/test/lpdf.cc b/test/lpdf.cc index f868aba05..90512e20c 100644 --- a/test/lpdf.cc +++ b/test/lpdf.cc @@ -1,8 +1,8 @@ #include -#include #include // lgamma, lmgamma #include +#include #include "algorithm_state.pb.h" // #include "src/hierarchies/lin_reg_uni_hierarchy.h" diff --git a/test/mfm_mixing.cc b/test/mfm_mixing.cc index 16fb815c1..601dbf0e8 100644 --- a/test/mfm_mixing.cc +++ b/test/mfm_mixing.cc @@ -1,9 +1,9 @@ #include -#include #include +#include TEST(mfm_mixing, mfm_mixing_test) { ASSERT_EQ(1, 1); ASSERT_GT(0, 1); -} \ No newline at end of file +} diff --git a/test/prior_models.cc b/test/prior_models.cc index 0ef122c98..b64b8832d 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -1,8 +1,8 @@ #include -#include #include #include +#include #include "algorithm_state.pb.h" #include "hierarchy_prior.pb.h" diff --git a/test/proto_utils.cc b/test/proto_utils.cc index baf904824..1ca53d155 100644 --- a/test/proto_utils.cc +++ b/test/proto_utils.cc @@ -2,7 +2,7 @@ #include -#include +#include #include "matrix.pb.h" diff --git a/test/rng.cc b/test/rng.cc index 79c29572a..47de9259a 100644 --- a/test/rng.cc +++ b/test/rng.cc @@ -2,7 +2,7 @@ #include -#include +#include #include "src/hierarchies/nnig_hierarchy.h" #include "src/utils/distributions.h" diff --git a/test/semi_hdp.cc b/test/semi_hdp.cc index 830933d9c..df55b6830 100644 --- a/test/semi_hdp.cc +++ b/test/semi_hdp.cc @@ -1,7 +1,7 @@ #include -#include #include +#include #include #include "semihdp.pb.h" From e9da1e217efd1345ded7d14d1b70f708519486d2 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 27 Jan 2022 20:57:49 +0100 Subject: [PATCH 111/317] Started documenting (ONGOING) --- src/hierarchies/base_hierarchy.h | 275 +++++++------------------------ 1 file changed, 58 insertions(+), 217 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 734aa227a..61c0940d1 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -22,19 +22,24 @@ //! class for further information). It includes class members and some more //! functions which could not be implemented in the non-templatized abstract //! class. -//! See, for instance, `ConjugateHierarchy` and `NNIGHierarchy` to better -//! understand the CRTP patterns. +//! See, for instance, `NNIGHierarchy` to better understand the CRTP patterns. //! @tparam Derived Name of the implemented derived class -//! @tparam State Class name of the container for state values -//! @tparam Hyperparams Class name of the container for hyperprior parameters -//! @tparam Prior Class name of the container for prior parameters +//! @tparam Likelihood Class name of the likelihood model for the hierarchy +//! @tparam PriorModel Class name of the prior model for the hierarchy +//! @tparam Updater Class name for the update algorithm used for posterior sampling template class BaseHierarchy : public AbstractHierarchy { protected: + + //! Container for the likelihood of the hierarchy std::shared_ptr like = std::make_shared(); + + //! Container for the prior model of the hierarchy std::shared_ptr prior = std::make_shared(); + + //! Container for the update algorithm adopted std::shared_ptr updater = std::make_shared(); public: @@ -46,6 +51,7 @@ class BaseHierarchy : public AbstractHierarchy { void set_prior(std::shared_ptr prior_) { prior = prior_; }; void set_updater(std::shared_ptr updater_) { updater = updater_; }; + //! Returns an independent, data-less copy of this object std::shared_ptr clone() const override { // Create copy of the hierarchy auto out = std::make_shared(static_cast(*this)); @@ -55,16 +61,22 @@ class BaseHierarchy : public AbstractHierarchy { return out; }; + // NOT SURE THIS IS CORRECT, MAYBE OVERRIDE GET_LIKE_LPDF? OR THIS IS EVEN UNNECESSARY double like_lpdf(const Eigen::RowVectorXd &datum) const override { return like->lpdf(datum); } + //! Evaluates the log-likelihood of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf Eigen::VectorXd like_lpdf_grid(const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const override { return like->lpdf_grid(data, covariates); }; + // ADD EXCEPTION HANDLING double get_marg_lpdf( const HyperParams ¶ms, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const { @@ -75,12 +87,14 @@ class BaseHierarchy : public AbstractHierarchy { } } + // ADD EXCEPTION HANDLING double prior_pred_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const override { return get_marg_lpdf(prior->get_hypers(), datum, covariate); } + // ADD EXCEPTION HANDLING Eigen::VectorXd prior_pred_lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { @@ -107,12 +121,14 @@ class BaseHierarchy : public AbstractHierarchy { return lpdf; } + // ADD EXCEPTION HANDLING double conditional_pred_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const override { return get_marg_lpdf(prior->get_posterior_hypers(), datum, covariate); } + // ADD EXCEPTION HANDLING Eigen::VectorXd conditional_pred_lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { @@ -139,18 +155,18 @@ class BaseHierarchy : public AbstractHierarchy { return lpdf; } + //! Generates new state values from the centering prior distribution void sample_prior() override { - // int card = like->get_card(); like->set_state_from_proto(*prior->sample(false), false); - // like->set_card(card); }; + //! Generates new state values from the centering posterior distribution + //! @param update_params Save posterior hypers after the computation? void sample_full_cond(bool update_params = false) override { - // int card = like->get_card(); updater->draw(*like, *prior, update_params); - // like->set_card(card); }; + //! Overloaded version of sample_full_cond(bool), mainly used for debugging void sample_full_cond( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override { @@ -183,24 +199,31 @@ class BaseHierarchy : public AbstractHierarchy { prior->update_hypers(states); }; + //! Returns the class of the current state auto get_state() const -> decltype(like->get_state()) { return like->get_state(); }; + //! Returns the current cardinality of the cluster int get_card() const override { return like->get_card(); }; + //! Returns the logarithm of the current cardinality of the cluster double get_log_card() const override { return like->get_log_card(); }; + //! Returns the indexes of data points belonging to this cluster std::set get_data_idx() const override { return like->get_data_idx(); }; + //! Returns a pointer to the Protobuf message of the prior of this cluster google::protobuf::Message *get_mutable_prior() { return prior->get_mutable_prior(); }; + //! Writes current state to a Protobuf message by pointer void write_state_to_proto(google::protobuf::Message *out) const override { like->write_state_to_proto(out); }; + //! Writes current values of the hyperparameters to a Protobuf message by pointer void write_hypers_to_proto(google::protobuf::Message *out) const override { prior->write_hypers_to_proto(out); }; @@ -214,6 +237,7 @@ class BaseHierarchy : public AbstractHierarchy { prior->set_hypers_from_proto(state_); }; + //! Adds a datum and its index to the hierarchy void add_datum( const int id, const Eigen::RowVectorXd &datum, const bool update_params = false, @@ -222,6 +246,7 @@ class BaseHierarchy : public AbstractHierarchy { if (update_params) updater->compute_posterior_hypers(*like, *prior); }; + //! Removes a datum and its index from the hierarchy void remove_datum( const int id, const Eigen::RowVectorXd &datum, const bool update_params = false, @@ -230,6 +255,7 @@ class BaseHierarchy : public AbstractHierarchy { if (update_params) updater->compute_posterior_hypers(*like, *prior); }; + //! Main function that initializes members to appropriate values void initialize() override { prior->initialize(); if (is_conjugate()) @@ -247,8 +273,10 @@ class BaseHierarchy : public AbstractHierarchy { protected: + //! Initializes state parameters to appropriate values virtual void initialize_state() = 0; + // ADD EXEPTION HANDLING virtual double marg_lpdf(const HyperParams ¶ms, const Eigen::RowVectorXd &datum) const { if (!is_conjugate()) { @@ -259,6 +287,7 @@ class BaseHierarchy : public AbstractHierarchy { } } + // ADD EXEPTION HANDLING virtual double marg_lpdf(const HyperParams ¶ms, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const { @@ -271,163 +300,83 @@ class BaseHierarchy : public AbstractHierarchy { } }; -// //! Returns an independent, data-less copy of this object -// virtual std::shared_ptr clone() const override { -// auto out = std::make_shared(static_cast(*this)); out->clear_data(); out->clear_summary_statistics(); return -// out; -// } - -// //! Evaluates the log-likelihood of data in a grid of points -// //! @param data Grid of points (by row) which are to be evaluated -// //! @param covariates (Optional) covariate vectors associated to data -// //! @return The evaluation of the lpdf -// virtual Eigen::VectorXd like_lpdf_grid( -// const Eigen::MatrixXd &data, -// const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, -// 0)) const -// override; - -// //! Generates new state values from the centering prior distribution -// void sample_prior() override { -// state = static_cast(this)->draw(*hypers); -// } - -// //! Overloaded version of sample_full_cond(bool), mainly used for -// debugging virtual void sample_full_cond( -// const Eigen::MatrixXd &data, -// const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override; - -// //! Returns the current cardinality of the cluster -// int get_card() const override { return card; } - -// //! Returns the logarithm of the current cardinality of the cluster -// double get_log_card() const override { return log_card; } - -// //! Returns the indexes of data points belonging to this cluster -// std::set get_data_idx() const override { return cluster_data_idx; } - -// //! Returns a pointer to the Protobuf message of the prior of this cluster -// virtual google::protobuf::Message *get_mutable_prior() override { -// if (prior == nullptr) { -// create_empty_prior(); -// } -// return prior.get(); -// } - -// //! Writes current state to a Protobuf message by pointer -// void write_state_to_proto(google::protobuf::Message *out) const override; - -// //! Writes current values of the hyperparameters to a Protobuf message by -// //! pointer -// void write_hypers_to_proto(google::protobuf::Message *out) const override; +// TODO: Move definitions outside the class to improve code cleaness +// TODO: Move this docs in the right place -// //! Returns the struct of the current state -// State get_state() const { return state; } - -// //! Returns the struct of the current prior hyperparameters +//! Returns the struct of the current prior hyperparameters // Hyperparams get_hypers() const { return *hypers; } -// //! Returns the struct of the current posterior hyperparameters +//! Returns the struct of the current posterior hyperparameters // Hyperparams get_posterior_hypers() const { return posterior_hypers; } -// //! Adds a datum and its index to the hierarchy -// void add_datum( -// const int id, const Eigen::RowVectorXd &datum, -// const bool update_params = false, -// const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; - -// //! Removes a datum and its index from the hierarchy -// void remove_datum( -// const int id, const Eigen::RowVectorXd &datum, -// const bool update_params = false, -// const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; - -// //! Main function that initializes members to appropriate values -// void initialize() override { -// hypers = std::make_shared(); -// check_prior_is_set(); -// initialize_hypers(); -// initialize_state(); -// posterior_hypers = *hypers; -// clear_data(); -// clear_summary_statistics(); -// } -// protected: -// //! Raises an error if the prior pointer is not initialized +//! Raises an error if the prior pointer is not initialized // void check_prior_is_set() const { // if (prior == nullptr) { // throw std::invalid_argument("Hierarchy prior was not provided"); // } // } -// //! Re-initializes the prior of the hierarchy to a newly created object +//! Re-initializes the prior of the hierarchy to a newly created object // void create_empty_prior() { prior.reset(new Prior); } -// //! Sets the cardinality of the cluster +//! Sets the cardinality of the cluster // void set_card(const int card_) { // card = card_; // log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); // } -// //! Writes current state to a Protobuf message and return a shared_ptr -// //! New hierarchies have to first modify the field 'oneof val' in the -// //! AlgoritmState::ClusterState message by adding the appropriate type +//! Writes current state to a Protobuf message and return a shared_ptr +//! New hierarchies have to first modify the field 'oneof val' in the +//! AlgoritmState::ClusterState message by adding the appropriate type // virtual std::shared_ptr // get_state_proto() const = 0; -// //! Initializes state parameters to appropriate values -// virtual void initialize_state() = 0; -// //! Writes current value of hyperparameters to a Protobuf message and -// //! return a shared_ptr. -// //! New hierarchies have to first modify the field 'oneof val' in the -// //! AlgoritmState::HierarchyHypers message by adding the appropriate type +//! Writes current value of hyperparameters to a Protobuf message and +//! return a shared_ptr. +//! New hierarchies have to first modify the field 'oneof val' in the +//! AlgoritmState::HierarchyHypers message by adding the appropriate type // virtual std::shared_ptr // get_hypers_proto() const = 0; -// //! Initializes hierarchy hyperparameters to appropriate values +//! Initializes hierarchy hyperparameters to appropriate values // virtual void initialize_hypers() = 0; -// //! Resets cardinality and indexes of data in this cluster +//! Resets cardinality and indexes of data in this cluster // void clear_data() { // set_card(0); // cluster_data_idx = std::set(); // } -// virtual void clear_summary_statistics() = 0; - -// //! Down-casts the given generic proto message to a ClusterState proto +//! Down-casts the given generic proto message to a ClusterState proto // bayesmix::AlgorithmState::ClusterState *downcast_state( // google::protobuf::Message *state_) const { // return google::protobuf::internal::down_cast< // bayesmix::AlgorithmState::ClusterState *>(state_); // } -// //! Down-casts the given generic proto message to a ClusterState proto +//! Down-casts the given generic proto message to a ClusterState proto // const bayesmix::AlgorithmState::ClusterState &downcast_state( // const google::protobuf::Message &state_) const { // return google::protobuf::internal::down_cast< // const bayesmix::AlgorithmState::ClusterState &>(state_); // } -// //! Down-casts the given generic proto message to a HierarchyHypers proto +//! Down-casts the given generic proto message to a HierarchyHypers proto // bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( // google::protobuf::Message *state_) const { // return google::protobuf::internal::down_cast< // bayesmix::AlgorithmState::HierarchyHypers *>(state_); // } -// //! Down-casts the given generic proto message to a HierarchyHypers proto +//! Down-casts the given generic proto message to a HierarchyHypers proto // const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( // const google::protobuf::Message &state_) const { // return google::protobuf::internal::down_cast< // const bayesmix::AlgorithmState::HierarchyHypers &>(state_); // } -// //! Container for state values -// State state; // //! Container for prior hyperparameters values // std::shared_ptr hypers; @@ -448,112 +397,4 @@ class BaseHierarchy : public AbstractHierarchy { // double log_card = stan::math::NEGATIVE_INFTY; // }; -// template void BaseHierarchy::add_datum( -// const int id, const Eigen::RowVectorXd &datum, -// const bool update_params /*= false*/, -// const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) { -// assert(cluster_data_idx.find(id) == cluster_data_idx.end()); -// card += 1; -// log_card = std::log(card); -// static_cast(this)->update_ss(datum, covariate, true); -// cluster_data_idx.insert(id); -// if (update_params) { -// static_cast(this)->save_posterior_hypers(); -// } -// } - -// template void BaseHierarchy::remove_datum( -// const int id, const Eigen::RowVectorXd &datum, -// const bool update_params /*= false*/, -// const Eigen::RowVectorXd &covariate /* = Eigen::RowVectorXd(0)*/) { -// static_cast(this)->update_ss(datum, covariate, false); -// set_card(card - 1); -// auto it = cluster_data_idx.find(id); -// assert(it != cluster_data_idx.end()); -// cluster_data_idx.erase(it); -// if (update_params) { -// static_cast(this)->save_posterior_hypers(); -// } -// } - -// template void BaseHierarchy::write_state_to_proto( -// google::protobuf::Message *out) const { -// std::shared_ptr state_ = -// get_state_proto(); -// auto *out_cast = downcast_state(out); -// out_cast->CopyFrom(*state_.get()); -// out_cast->set_cardinality(card); -// } - -// template void BaseHierarchy::write_hypers_to_proto( -// google::protobuf::Message *out) const { -// std::shared_ptr hypers_ = -// get_hypers_proto(); -// auto *out_cast = downcast_hypers(out); -// out_cast->CopyFrom(*hypers_.get()); -// } - -// template Eigen::VectorXd BaseHierarchy::like_lpdf_grid( -// const Eigen::MatrixXd &data, -// const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { -// Eigen::VectorXd lpdf(data.rows()); -// if (covariates.cols() == 0) { -// // Pass null value as covariate -// for (int i = 0; i < data.rows(); i++) { -// lpdf(i) = static_cast(this)->get_like_lpdf( -// data.row(i), Eigen::RowVectorXd(0)); -// } -// } else if (covariates.rows() == 1) { -// // Use unique covariate -// for (int i = 0; i < data.rows(); i++) { -// lpdf(i) = static_cast(this)->get_like_lpdf( -// data.row(i), covariates.row(0)); -// } -// } else { -// // Use different covariates -// for (int i = 0; i < data.rows(); i++) { -// lpdf(i) = static_cast(this)->get_like_lpdf( -// data.row(i), covariates.row(i)); -// } -// } -// return lpdf; -// } - -// template void BaseHierarchy::sample_full_cond( -// const Eigen::MatrixXd &data, -// const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) { -// clear_data(); -// clear_summary_statistics(); -// if (covariates.cols() == 0) { -// // Pass null value as covariate -// for (int i = 0; i < data.rows(); i++) { -// static_cast(this)->add_datum(i, data.row(i), false, -// Eigen::RowVectorXd(0)); -// } -// } else if (covariates.rows() == 1) { -// // Use unique covariate -// for (int i = 0; i < data.rows(); i++) { -// static_cast(this)->add_datum(i, data.row(i), false, -// covariates.row(0)); -// } -// } else { -// // Use different covariates -// for (int i = 0; i < data.rows(); i++) { -// static_cast(this)->add_datum(i, data.row(i), false, -// covariates.row(i)); -// } -// } -// static_cast(this)->sample_full_cond(true); -// } - #endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ From 112f95e895e6c286f2aa093fda1f02faecab6120 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Fri, 28 Jan 2022 09:26:22 +0100 Subject: [PATCH 112/317] comment fixes --- CMakeLists.txt | 2 - src/hierarchies/abstract_hierarchy.h | 19 +++ src/hierarchies/base_hierarchy.h | 117 ++---------------- .../likelihoods/abstract_likelihood.h | 7 ++ src/hierarchies/likelihoods/base_likelihood.h | 16 +++ .../likelihoods/uni_norm_likelihood.h | 6 - src/hierarchies/priors/abstract_prior_model.h | 7 ++ src/hierarchies/priors/base_prior_model.h | 14 +++ src/hierarchies/priors/nig_prior_model.h | 5 - src/hierarchies/updaters/CMakeLists.txt | 8 +- src/hierarchies/updaters/abstract_updater.h | 11 +- src/hierarchies/updaters/mala_updater.h | 7 +- src/hierarchies/updaters/metropolis_updater.h | 18 +-- test/CMakeLists.txt | 20 +-- test/mfm_mixing.cc | 9 -- 15 files changed, 103 insertions(+), 163 deletions(-) delete mode 100644 test/mfm_mixing.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 5d600cce7..641474329 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -158,8 +158,6 @@ if (BUILD_RUN) target_compile_options(run_mcmc PUBLIC ${COMPILE_OPTIONS}) endif() -add_subdirectory(test) - if (NOT DISABLE_TESTS) add_subdirectory(test) endif() diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 963ac12b8..279f1d8b1 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -49,12 +49,17 @@ //! instance the proto/ls_state.proto and proto/hierarchy_prior.proto files) //! and their relative class methods. +class target_lpdf_unconstrained; + class AbstractHierarchy { public: virtual void set_likelihood(std::shared_ptr like_) = 0; virtual void set_prior(std::shared_ptr prior_) = 0; virtual void set_updater(std::shared_ptr updater_) = 0; + virtual std::shared_ptr get_likelihood() = 0; + virtual std::shared_ptr get_prior() = 0; + virtual ~AbstractHierarchy() = default; //! Returns an independent, data-less copy of this object @@ -284,4 +289,18 @@ class AbstractHierarchy { } }; +class target_lpdf_unconstrained { + protected: + AbstractHierarchy *parent; + + public: + target_lpdf_unconstrained(AbstractHierarchy *p) : parent(p) {} + + template + T operator()(const Eigen::Matrix &x) const { + return parent->get_likelihood()->cluster_lpdf_from_unconstrained(x) + + parent->get_prior()->lpdf_from_unconstrained(x); + } +}; + #endif // BAYESMIX_HIERARCHIES_ABSTRACT_HIERARCHY_H_ diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index e7fe80785..4f9733a1e 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -12,11 +12,9 @@ #include "abstract_hierarchy.h" #include "algorithm_state.pb.h" #include "hierarchy_id.pb.h" +#include "src/hierarchies/updaters/target_lpdf_unconstrained.h" #include "src/utils/rng.h" -template -class target_lpdf_unconstrained; - //! Base template class for a hierarchy object. //! This class is a templatized version of, and derived from, the @@ -35,9 +33,6 @@ class target_lpdf_unconstrained; template class BaseHierarchy : public AbstractHierarchy { - template - friend class target_lpdf_unconstrained; - protected: std::shared_ptr like = std::make_shared(); std::shared_ptr prior = std::make_shared(); @@ -72,6 +67,11 @@ class BaseHierarchy : public AbstractHierarchy { updater = std::static_pointer_cast(updater_); }; + std::shared_ptr get_likelihood() override { + return like; + } + std::shared_ptr get_prior() override { return prior; } + std::shared_ptr clone() const override { // Create copy of the hierarchy auto out = std::make_shared(static_cast(*this)); @@ -171,7 +171,10 @@ class BaseHierarchy : public AbstractHierarchy { // like->set_card(card); }; - void sample_full_cond(bool update_params = false) override; + void sample_full_cond(bool update_params = false) override { + target_lpdf_unconstrained target(this); + updater->draw(*like, *prior, update_params, target); + }; void sample_full_cond( const Eigen::MatrixXd &data, @@ -291,104 +294,4 @@ class BaseHierarchy : public AbstractHierarchy { } }; -template -class target_lpdf_unconstrained { - protected: - const DerivedHierarchy &parent; - - public: - target_lpdf_unconstrained(const DerivedHierarchy &p) : parent(p) {} - - template - T operator()(const Eigen::Matrix &x) const { - return parent.like->cluster_lpdf_from_unconstrained(x) + - parent.prior->lpdf_from_unconstrained(x); - } -}; - -template -void BaseHierarchy::sample_full_cond( - bool update_params) { - target_lpdf_unconstrained target(static_cast(*this)); - updater->draw(*like, *prior, update_params, target); -}; - - // template void BaseHierarchy::write_state_to_proto( - // google::protobuf::Message *out) const { - // std::shared_ptr state_ = - // get_state_proto(); - // auto *out_cast = downcast_state(out); - // out_cast->CopyFrom(*state_.get()); - // out_cast->set_cardinality(card); - // } - - // template void BaseHierarchy::write_hypers_to_proto( - // google::protobuf::Message *out) const { - // std::shared_ptr hypers_ = - // get_hypers_proto(); - // auto *out_cast = downcast_hypers(out); - // out_cast->CopyFrom(*hypers_.get()); - // } - - // template Eigen::VectorXd BaseHierarchy::like_lpdf_grid( - // const Eigen::MatrixXd &data, - // const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { - // Eigen::VectorXd lpdf(data.rows()); - // if (covariates.cols() == 0) { - // // Pass null value as covariate - // for (int i = 0; i < data.rows(); i++) { - // lpdf(i) = static_cast(this)->get_like_lpdf( - // data.row(i), Eigen::RowVectorXd(0)); - // } - // } else if (covariates.rows() == 1) { - // // Use unique covariate - // for (int i = 0; i < data.rows(); i++) { - // lpdf(i) = static_cast(this)->get_like_lpdf( - // data.row(i), covariates.row(0)); - // } - // } else { - // // Use different covariates - // for (int i = 0; i < data.rows(); i++) { - // lpdf(i) = static_cast(this)->get_like_lpdf( - // data.row(i), covariates.row(i)); - // } - // } - // return lpdf; - // } - - // template void BaseHierarchy::sample_full_cond( - // const Eigen::MatrixXd &data, - // const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) { - // clear_data(); - // clear_summary_statistics(); - // if (covariates.cols() == 0) { - // // Pass null value as covariate - // for (int i = 0; i < data.rows(); i++) { - // static_cast(this)->add_datum(i, data.row(i), false, - // Eigen::RowVectorXd(0)); - // } - // } else if (covariates.rows() == 1) { - // // Use unique covariate - // for (int i = 0; i < data.rows(); i++) { - // static_cast(this)->add_datum(i, data.row(i), false, - // covariates.row(0)); - // } - // } else { - // // Use different covariates - // for (int i = 0; i < data.rows(); i++) { - // static_cast(this)->add_datum(i, data.row(i), false, - // covariates.row(i)); - // } - // } - // static_cast(this)->sample_full_cond(true); - // } - #endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 987461cbe..c19bf882c 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -43,6 +43,13 @@ class AbstractLikelihood { "cluster_lpdf_from_unconstrained() not yet implemented"); } + virtual stan::math::var cluster_lpdf_from_unconstrained( + Eigen::Matrix unconstrained_params) + const { + throw std::runtime_error( + "cluster_lpdf_from_unconstrained() not yet implemented"); + } + virtual Eigen::VectorXd lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const = 0; diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 2e294b330..81ed354fc 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -25,6 +25,22 @@ class BaseLikelihood : public AbstractLikelihood { return out; } + // The unconstrained parameters are mean and log(var) + double cluster_lpdf_from_unconstrained( + Eigen::VectorXd unconstrained_params) const override { + return static_cast(*this) + .template cluster_lpdf_from_unconstrained( + unconstrained_params); + } + + stan::math::var cluster_lpdf_from_unconstrained( + Eigen::Matrix unconstrained_params) + const override { + return static_cast(*this) + .template cluster_lpdf_from_unconstrained( + unconstrained_params); + } + virtual Eigen::VectorXd lpdf_grid(const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const override; diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 407083f00..68c382fd4 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -36,12 +36,6 @@ class UniNormLikelihood return out; } - // The unconstrained parameters are mean and log(var) - double cluster_lpdf_from_unconstrained( - Eigen::VectorXd unconstrained_params) const override { - return this->cluster_lpdf_from_unconstrained(unconstrained_params); - } - protected: std::shared_ptr get_state_proto() const override; diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index cbf7ba7f6..357ea5516 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -30,6 +30,13 @@ class AbstractPriorModel { throw std::runtime_error("lpdf_from_unconstrained() not yet implemented"); } + virtual stan::math::var lpdf_from_unconstrained( + Eigen::Matrix unconstrained_params) + const { + throw std::runtime_error( + "cluster_lpdf_from_unconstrained() not yet implemented"); + } + // Da pensare, come restituisco lo stato? magari un pointer? Oppure delego virtual std::shared_ptr sample( bool use_post_hypers) = 0; diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 924584aed..faf9be373 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -21,6 +21,20 @@ class BasePriorModel : public AbstractPriorModel { ~BasePriorModel() = default; + double lpdf_from_unconstrained( + Eigen::VectorXd unconstrained_params) const override { + return static_cast(*this) + .template lpdf_from_unconstrained(unconstrained_params); + } + + stan::math::var lpdf_from_unconstrained( + Eigen::Matrix unconstrained_params) + const override { + return static_cast(*this) + .template lpdf_from_unconstrained( + unconstrained_params); + } + virtual std::shared_ptr clone() const override; virtual google::protobuf::Message *get_mutable_prior() override; diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index e12300835..b0f491e9f 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -37,11 +37,6 @@ class NIGPriorModel : public BasePriorModellpdf_from_unconstrained(unconstrained_params); - } - std::shared_ptr sample( bool use_post_hypers) override; diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index b366686f9..3d3481f58 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -1,8 +1,10 @@ target_sources(bayesmix PUBLIC abstract_updater.h - #conjugate_updater.h + conjugate_updater.h + mala_updater.h + metropolis_updater.h # nnig_updater.h # nnig_updater.cc - metropolis_updater.h - random_walk_updater.h + # random_walk_updater.h + target_lpdf_unconstrained.h ) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 9a3a78d57..35ac6cde4 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -1,15 +1,22 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ +#include "src/hierarchies/abstract_hierarchy.h" #include "src/hierarchies/likelihoods/abstract_likelihood.h" #include "src/hierarchies/priors/abstract_prior_model.h" +class target_lpdf_unconstrained; + class AbstractUpdater { public: virtual ~AbstractUpdater() = default; + virtual bool is_conjugate() const { return false; }; - // virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, - // bool update_params) = 0; + + virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, + bool update_params, + target_lpdf_unconstrained &target_lpdf) = 0; + virtual void compute_posterior_hypers(AbstractLikelihood &like, AbstractPriorModel &prior) { throw std::runtime_error("compute_posterior_hypers not implemented"); diff --git a/src/hierarchies/updaters/mala_updater.h b/src/hierarchies/updaters/mala_updater.h index 3c0d9d012..a619780f2 100644 --- a/src/hierarchies/updaters/mala_updater.h +++ b/src/hierarchies/updaters/mala_updater.h @@ -15,10 +15,10 @@ class MalaUpdater : public MetropolisUpdater { MalaUpdater(double step_size) : step_size(step_size) {} - template Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, AbstractLikelihood &like, - AbstractPriorModel &prior, F &target_lpdf) { + AbstractPriorModel &prior, + target_lpdf_unconstrained &target_lpdf) { Eigen::VectorXd noise(curr_state.size()); auto &rng = bayesmix::Rng::Instance().get(); double noise_scale = std::sqrt(2 * step_size); @@ -31,10 +31,9 @@ class MalaUpdater : public MetropolisUpdater { return curr_state + step_size * grad + noise; } - template double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, AbstractLikelihood &like, AbstractPriorModel &prior, - F &target_lpdf) { + target_lpdf_unconstrained &target_lpdf) { double out; Eigen::VectorXd grad; double tmp; diff --git a/src/hierarchies/updaters/metropolis_updater.h b/src/hierarchies/updaters/metropolis_updater.h index 70f91b9f4..46cc1bc50 100644 --- a/src/hierarchies/updaters/metropolis_updater.h +++ b/src/hierarchies/updaters/metropolis_updater.h @@ -2,13 +2,14 @@ #define BAYESMIX_HIERARCHIES_UPDATERS_METROPOLIS_UPDATER_H_ #include "abstract_updater.h" +#include "target_lpdf_unconstrained.h" template class MetropolisUpdater : public AbstractUpdater { public: - template void draw(AbstractLikelihood &like, AbstractPriorModel &prior, - bool update_params, F &target_lpdf) { + bool update_params, + target_lpdf_unconstrained &target_lpdf) override { Eigen::VectorXd curr_state = like.get_unconstrained_state(); Eigen::VectorXd prop_state = static_cast(this)->sample_proposal( @@ -26,19 +27,6 @@ class MetropolisUpdater : public AbstractUpdater { like.set_state_from_unconstrained(prop_state); } } - - // template - // virtual Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, - // AbstractLikelihood &like, - // AbstractPriorModel &prior, - // F& target_lpdf) = 0; - - // template - // virtual double proposal_lpdf(Eigen::VectorXd prop_state, - // Eigen::VectorXd curr_state, - // AbstractLikelihood &like, - // AbstractPriorModel &prior, - // F& target_lpdf) = 0; }; #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d58fb0afb..8d147234c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -16,20 +16,20 @@ set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) add_executable(test_bayesmix $ - # write_proto.cc - # proto_utils.cc + write_proto.cc + proto_utils.cc likelihoods.cc prior_models.cc - # hierarchies.cc - # lpdf.cc + hierarchies.cc + lpdf.cc # priors.cc // OLD, USEREI prior_models.cc - # eigen_utils.cc + eigen_utils.cc distributions.cc - # semi_hdp.cc - # collectors.cc - # runtime.cc - # rng.cc - # logit_sb.cc + semi_hdp.cc + collectors.cc + runtime.cc + rng.cc + logit_sb.cc gradient.cc ) diff --git a/test/mfm_mixing.cc b/test/mfm_mixing.cc deleted file mode 100644 index 601dbf0e8..000000000 --- a/test/mfm_mixing.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include - -#include -#include - -TEST(mfm_mixing, mfm_mixing_test) { - ASSERT_EQ(1, 1); - ASSERT_GT(0, 1); -} From ac80834b817d88f2190b66f5e5b4f308c2073e37 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Fri, 28 Jan 2022 14:58:13 +0100 Subject: [PATCH 113/317] improvements --- src/hierarchies/abstract_hierarchy.h | 16 ---------------- src/hierarchies/base_hierarchy.h | 5 ++--- src/hierarchies/updaters/CMakeLists.txt | 1 + src/hierarchies/updaters/abstract_updater.h | 6 ++---- src/hierarchies/updaters/metropolis_updater.h | 4 ++-- 5 files changed, 7 insertions(+), 25 deletions(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 279f1d8b1..33d63b38a 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -49,8 +49,6 @@ //! instance the proto/ls_state.proto and proto/hierarchy_prior.proto files) //! and their relative class methods. -class target_lpdf_unconstrained; - class AbstractHierarchy { public: virtual void set_likelihood(std::shared_ptr like_) = 0; @@ -289,18 +287,4 @@ class AbstractHierarchy { } }; -class target_lpdf_unconstrained { - protected: - AbstractHierarchy *parent; - - public: - target_lpdf_unconstrained(AbstractHierarchy *p) : parent(p) {} - - template - T operator()(const Eigen::Matrix &x) const { - return parent->get_likelihood()->cluster_lpdf_from_unconstrained(x) + - parent->get_prior()->lpdf_from_unconstrained(x); - } -}; - #endif // BAYESMIX_HIERARCHIES_ABSTRACT_HIERARCHY_H_ diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 4f9733a1e..2bcf1acdd 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -12,8 +12,8 @@ #include "abstract_hierarchy.h" #include "algorithm_state.pb.h" #include "hierarchy_id.pb.h" -#include "src/hierarchies/updaters/target_lpdf_unconstrained.h" #include "src/utils/rng.h" +#include "updaters/target_lpdf_unconstrained.h" //! Base template class for a hierarchy object. @@ -172,8 +172,7 @@ class BaseHierarchy : public AbstractHierarchy { }; void sample_full_cond(bool update_params = false) override { - target_lpdf_unconstrained target(this); - updater->draw(*like, *prior, update_params, target); + updater->draw(*like, *prior, update_params); }; void sample_full_cond( diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index 3d3481f58..e6f1faf5d 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -7,4 +7,5 @@ target_sources(bayesmix PUBLIC # nnig_updater.cc # random_walk_updater.h target_lpdf_unconstrained.h + target_lpdf_unconstrained.cc ) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 35ac6cde4..c8163dfbd 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -4,8 +4,7 @@ #include "src/hierarchies/abstract_hierarchy.h" #include "src/hierarchies/likelihoods/abstract_likelihood.h" #include "src/hierarchies/priors/abstract_prior_model.h" - -class target_lpdf_unconstrained; +#include "src/hierarchies/updaters/target_lpdf_unconstrained.h" class AbstractUpdater { public: @@ -14,8 +13,7 @@ class AbstractUpdater { virtual bool is_conjugate() const { return false; }; virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, - bool update_params, - target_lpdf_unconstrained &target_lpdf) = 0; + bool update_params) = 0; virtual void compute_posterior_hypers(AbstractLikelihood &like, AbstractPriorModel &prior) { diff --git a/src/hierarchies/updaters/metropolis_updater.h b/src/hierarchies/updaters/metropolis_updater.h index 46cc1bc50..9cc662395 100644 --- a/src/hierarchies/updaters/metropolis_updater.h +++ b/src/hierarchies/updaters/metropolis_updater.h @@ -8,8 +8,8 @@ template class MetropolisUpdater : public AbstractUpdater { public: void draw(AbstractLikelihood &like, AbstractPriorModel &prior, - bool update_params, - target_lpdf_unconstrained &target_lpdf) override { + bool update_params) override { + target_lpdf_unconstrained target_lpdf(&like, &prior); Eigen::VectorXd curr_state = like.get_unconstrained_state(); Eigen::VectorXd prop_state = static_cast(this)->sample_proposal( From 79132b27031c87086eadb59c1c477d777d597c31 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Fri, 28 Jan 2022 15:17:44 +0100 Subject: [PATCH 114/317] removed updater from template --- run_mcmc.cc | 2 -- src/hierarchies/base_hierarchy.h | 18 ++++++++++++------ src/hierarchies/nnig_hierarchy.h | 16 +++++++++------- src/hierarchies/updaters/CMakeLists.txt | 6 +++--- src/hierarchies/updaters/abstract_updater.h | 1 - src/hierarchies/updaters/conjugate_updater.h | 2 -- 6 files changed, 24 insertions(+), 21 deletions(-) diff --git a/run_mcmc.cc b/run_mcmc.cc index 6ad72ae43..b106e25d0 100644 --- a/run_mcmc.cc +++ b/run_mcmc.cc @@ -180,8 +180,6 @@ int main(int argc, char *argv[]) { std::cout << "hier->prior: \n" << hier->get_mutable_prior()->DebugString() << std::endl; - auto updater = std::make_shared(0.25); - hier->set_updater(updater); hier->initialize(); // Read data matrices diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 2bcf1acdd..47b7b652b 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -31,12 +31,12 @@ //! @tparam Hyperparams Class name of the container for hyperprior parameters //! @tparam Prior Class name of the container for prior parameters -template +template class BaseHierarchy : public AbstractHierarchy { protected: std::shared_ptr like = std::make_shared(); std::shared_ptr prior = std::make_shared(); - std::shared_ptr updater = std::make_shared(); + std::shared_ptr updater; public: using HyperParams = decltype(prior->get_hypers()); @@ -52,6 +52,8 @@ class BaseHierarchy : public AbstractHierarchy { } if (updater_) { set_updater(updater_); + } else { + static_cast(this)->set_default_updater(); } } @@ -64,9 +66,11 @@ class BaseHierarchy : public AbstractHierarchy { prior = std::static_pointer_cast(prior_); } void set_updater(std::shared_ptr updater_) override { - updater = std::static_pointer_cast(updater_); + updater = updater_; }; + virtual void set_default_updater() = 0; + std::shared_ptr get_likelihood() override { return like; } @@ -109,7 +113,8 @@ class BaseHierarchy : public AbstractHierarchy { Eigen::VectorXd prior_pred_lpdf_grid( const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { + const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) + const override { Eigen::VectorXd lpdf(data.rows()); if (covariates.cols() == 0) { // Pass null value as covariate @@ -141,7 +146,8 @@ class BaseHierarchy : public AbstractHierarchy { Eigen::VectorXd conditional_pred_lpdf_grid( const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { + const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) + const override { Eigen::VectorXd lpdf(data.rows()); if (covariates.cols() == 0) { // Pass null value as covariate @@ -217,7 +223,7 @@ class BaseHierarchy : public AbstractHierarchy { std::set get_data_idx() const override { return like->get_data_idx(); }; - google::protobuf::Message *get_mutable_prior() { + google::protobuf::Message *get_mutable_prior() override { return prior->get_mutable_prior(); }; diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index 645af2c96..fa797c5ac 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -15,22 +15,24 @@ #include "base_hierarchy.h" #include "likelihoods/uni_norm_likelihood.h" #include "priors/nig_prior_model.h" -// #include "updaters/nnig_updater.h" -#include "updaters/mala_updater.h" -#include "updaters/random_walk_updater.h" +#include "updaters/nnig_updater.h" -class NNIGHierarchy : public BaseHierarchy { +class NNIGHierarchy + : public BaseHierarchy { public: ~NNIGHierarchy() = default; - using BaseHierarchy::BaseHierarchy; + using BaseHierarchy::BaseHierarchy; bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNIG; } + void set_default_updater() override { + updater = std::make_shared(); + } + void initialize_state() override { // Get hypers auto hypers = prior->get_hypers(); diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index e6f1faf5d..8ce0f182d 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -3,9 +3,9 @@ target_sources(bayesmix PUBLIC conjugate_updater.h mala_updater.h metropolis_updater.h - # nnig_updater.h - # nnig_updater.cc - # random_walk_updater.h + nnig_updater.h + nnig_updater.cc + random_walk_updater.h target_lpdf_unconstrained.h target_lpdf_unconstrained.cc ) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index c8163dfbd..188c4e46c 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -1,7 +1,6 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_ -#include "src/hierarchies/abstract_hierarchy.h" #include "src/hierarchies/likelihoods/abstract_likelihood.h" #include "src/hierarchies/priors/abstract_prior_model.h" #include "src/hierarchies/updaters/target_lpdf_unconstrained.h" diff --git a/src/hierarchies/updaters/conjugate_updater.h b/src/hierarchies/updaters/conjugate_updater.h index 99552686e..b4f1cefe5 100644 --- a/src/hierarchies/updaters/conjugate_updater.h +++ b/src/hierarchies/updaters/conjugate_updater.h @@ -14,8 +14,6 @@ class ConjugateUpdater : public AbstractUpdater { bool is_conjugate() const override { return true; }; void draw(AbstractLikelihood& like, AbstractPriorModel& prior, bool update_params) override; - virtual void compute_posterior_hypers(AbstractLikelihood& like, - AbstractPriorModel& prior) = 0; protected: Likelihood& downcast_likelihood(AbstractLikelihood& like_); From e7c435896a5b7877572740324ac27567fe7f9e64 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 28 Jan 2022 21:13:32 +0100 Subject: [PATCH 115/317] MultiNormLikelihood class defined --- src/hierarchies/likelihoods/CMakeLists.txt | 2 + .../likelihoods/multi_norm_likelihood.cc | 56 +++++++++++++++++++ .../likelihoods/multi_norm_likelihood.h | 41 ++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 src/hierarchies/likelihoods/multi_norm_likelihood.cc create mode 100644 src/hierarchies/likelihoods/multi_norm_likelihood.h diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index 3e5495073..a30eb192f 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -5,4 +5,6 @@ target_sources(bayesmix states.h uni_norm_likelihood.h uni_norm_likelihood.cc + multi_norm_likelihood.h + multi_norm_likelihood.cc ) diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.cc b/src/hierarchies/likelihoods/multi_norm_likelihood.cc new file mode 100644 index 000000000..353e5df23 --- /dev/null +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.cc @@ -0,0 +1,56 @@ +#include "multi_norm_likelihood.h" + +#include "src/utils/distributions.h" +#include "src/utils/eigen_utils.h" +#include "src/utils/proto_utils.h" + +double MultiNormLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { + return bayesmix::multi_normal_prec_lpdf(datum, state.mean, state.prec_chol, + state.prec_logdet); +} + +void MultiNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, bool add) { + + // Prepare buffer in case dim is not defined yet + if (!dim) { + dim = datum.size(); + data_sum = Eigen::VectorXd::Zero(dim); + data_sum_squares = Eigen::MatrixXd::Zero(dim,dim); + } + + // Updates + if (add) { + data_sum += datum.transpose(); + data_sum_squares += datum.transpose() * datum; + } else { + data_sum -= datum.transpose(); + data_sum_squares -= datum.transpose() * datum; + } +} + +void MultiNormLikelihood::set_state_from_proto( + const google::protobuf::Message &state_, bool update_card) { + auto &statecast = downcast_state(state_); + state.mean = to_eigen(statecast.multi_ls_state().mean()); + state.prec = to_eigen(statecast.multi_ls_state().prec()); + state.prec_chol = to_eigen(statecast.multi_ls_state().prec_chol()); + Eigen::VectorXd diag = state.prec_chol.diagonal(); + state.prec_logdet = 2 * log(diag.array()).sum(); + if (update_card) set_card(statecast.cardinality()); +} + +std::shared_ptr +MultiNormLikelihood::get_state_proto() const { + bayesmix::MultiLSState state_; + bayesmix::to_proto(state.mean, state_.mutable_mean()); + bayesmix::to_proto(state.prec, state_.mutable_prec()); + bayesmix::to_proto(state.prec_chol, state_.mutable_prec_chol()); + auto out = std::make_shared(); + out->mutable_multi_ls_state()->CopyFrom(state_); + return out; +} + +void MultiNormLikelihood::clear_summary_statistics() { + data_sum = Eigen::VectorXd::Zero(dim); + data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); +} diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h new file mode 100644 index 000000000..4d8a5a3e8 --- /dev/null +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -0,0 +1,41 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_MULTI_NORM_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_MULTI_NORM_LIKELIHOOD_H_ + +#include + +#include +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states.h" + +class MultiNormLikelihood + : public BaseLikelihood { + public: + MultiNormLikelihood() = default; + ~MultiNormLikelihood() = default; + bool is_multivariate() const override { return true; }; + bool is_dependent() const override { return false; }; + void set_state_from_proto(const google::protobuf::Message &state_, + bool update_card = true) override; + void clear_summary_statistics() override; + + unsigned int get_dim() const { return dim; }; + Eigen::VectorXd get_data_sum() const { return data_sum; }; + Eigen::MatrixXd get_data_sum_squares() const { return data_sum_squares; }; + + protected: + std::shared_ptr get_state_proto() + const override; + double compute_lpdf(const Eigen::RowVectorXd &datum) const override; + void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; + + unsigned int dim; + Eigen::VectorXd data_sum; + Eigen::MatrixXd data_sum_squares; +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_MULTI_NORM_LIKELIHOOD_H_ From d1302b52a366a41cd37504ff25ed3652069e420a Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 28 Jan 2022 21:13:55 +0100 Subject: [PATCH 116/317] NWPriorModel class defined --- src/hierarchies/priors/CMakeLists.txt | 2 + src/hierarchies/priors/nw_prior_model.cc | 268 +++++++++++++++++++++++ src/hierarchies/priors/nw_prior_model.h | 46 ++++ 3 files changed, 316 insertions(+) create mode 100644 src/hierarchies/priors/nw_prior_model.cc create mode 100644 src/hierarchies/priors/nw_prior_model.h diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt index 7ed26bdeb..a283b7762 100644 --- a/src/hierarchies/priors/CMakeLists.txt +++ b/src/hierarchies/priors/CMakeLists.txt @@ -7,4 +7,6 @@ target_sources(bayesmix nig_prior_model.cc nxig_prior_model.h nxig_prior_model.cc + nw_prior_model.h + nw_prior_model.cc ) diff --git a/src/hierarchies/priors/nw_prior_model.cc b/src/hierarchies/priors/nw_prior_model.cc new file mode 100644 index 000000000..6bdecb724 --- /dev/null +++ b/src/hierarchies/priors/nw_prior_model.cc @@ -0,0 +1,268 @@ +#include "nw_prior_model.h" + +#include "src/utils/eigen_utils.h" +#include "src/utils/proto_utils.h" +#include "src/utils/distributions.h" + +void NWPriorModel::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers.mean = bayesmix::to_eigen(prior->fixed_values().mean()); + dim = hypers.mean.size(); + hypers.var_scaling = prior->fixed_values().var_scaling(); + hypers.scale = bayesmix::to_eigen(prior->fixed_values().scale()); + hypers.deg_free = prior->fixed_values().deg_free(); + // Check validity + if (hypers.var_scaling <= 0) { + throw std::invalid_argument("Variance-scaling parameter must be > 0"); + } + if (dim != hypers.scale.rows()) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + if (hypers.deg_free <= dim - 1) { + throw std::invalid_argument("Degrees of freedom parameter is not valid"); + } + } + + else if (prior->has_normal_mean_prior()) { + // Get hyperparameters + Eigen::VectorXd mu00 = + bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().mean()); + dim = mu00.size(); + Eigen::MatrixXd sigma00 = + bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().var()); + double lambda0 = prior->normal_mean_prior().var_scaling(); + Eigen::MatrixXd tau0 = + bayesmix::to_eigen(prior->normal_mean_prior().scale()); + double nu0 = prior->normal_mean_prior().deg_free(); + // Check validity + dim = mu00.size(); + if (sigma00.rows() != dim or tau0.rows() != dim) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + bayesmix::check_spd(sigma00); + if (lambda0 <= 0) { + throw std::invalid_argument("Variance-scaling parameter must be > 0"); + } + bayesmix::check_spd(tau0); + if (nu0 <= dim - 1) { + throw std::invalid_argument("Degrees of freedom parameter is not valid"); + } + // Set initial values + hypers.mean = mu00; + hypers.var_scaling = lambda0; + hypers.scale = tau0; + hypers.deg_free = nu0; + } + + else if (prior->has_ngiw_prior()) { + // Get hyperparameters: + // for mu0 + Eigen::VectorXd mu00 = + bayesmix::to_eigen(prior->ngiw_prior().mean_prior().mean()); + unsigned int dim = mu00.size(); + Eigen::MatrixXd sigma00 = + bayesmix::to_eigen(prior->ngiw_prior().mean_prior().var()); + // for lambda0 + double alpha00 = prior->ngiw_prior().var_scaling_prior().shape(); + double beta00 = prior->ngiw_prior().var_scaling_prior().rate(); + // for tau0 + double nu00 = prior->ngiw_prior().scale_prior().deg_free(); + Eigen::MatrixXd tau00 = + bayesmix::to_eigen(prior->ngiw_prior().scale_prior().scale()); + // for nu0 + double nu0 = prior->ngiw_prior().deg_free(); + // Check validity: + // dimensionality + if (sigma00.rows() != dim or tau00.rows() != dim) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + // for mu0 + bayesmix::check_spd(sigma00); + // for lambda0 + if (alpha00 <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (beta00 <= 0) { + throw std::invalid_argument("Rate parameter must be > 0"); + } + // for tau0 + if (nu00 <= 0) { + throw std::invalid_argument("Degrees of freedom parameter must be > 0"); + } + bayesmix::check_spd(tau00); + // check nu0 + if (nu0 <= dim - 1) { + throw std::invalid_argument("Degrees of freedom parameter is not valid"); + } + // Set initial values + hypers.mean = mu00; + hypers.var_scaling = alpha00 / beta00; + hypers.scale = tau00 / (nu00 + dim + 1); + hypers.deg_free = nu0; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } + hypers.scale_inv = stan::math::inverse_spd(hypers.scale); + hypers.scale_chol = Eigen::LLT(hypers.scale).matrixU(); +} + +double NWPriorModel::lpdf(const google::protobuf::Message &state_) { + auto &state = downcast_state(state_).multi_ls_state(); + Eigen::VectorXd mean = bayesmix::to_eigen(state.mean()); + Eigen::MatrixXd prec = bayesmix::to_eigen(state.prec()); + double target = stan::math::multi_normal_prec_lpdf(mean, hypers.mean, prec * hypers.var_scaling) + + stan::math::wishart_lpdf(prec, hypers.deg_free, hypers.scale); + return target; +} + +std::shared_ptr NWPriorModel::sample( + bool use_post_hypers) { + + auto &rng = bayesmix::Rng::Instance().get(); + + Hyperparams::NW params = use_post_hypers ? post_hypers : hypers; + + Eigen::MatrixXd tau_new = + stan::math::wishart_rng(params.deg_free, params.scale, rng); + + // Update state + State::MultiLS out; + out.mean = stan::math::multi_normal_prec_rng( + params.mean, tau_new * params.var_scaling, rng); + write_prec_to_state(tau_new, &out); + + // Translate to proto + bayesmix::Vector mean_proto; + bayesmix::Matrix prec_proto, prec_chol_proto; + bayesmix::to_proto(out.mean, &mean_proto); + bayesmix::to_proto(out.prec, &prec_proto); + bayesmix::to_proto(out.prec_chol, &prec_chol_proto); + + // Make output state + bayesmix::AlgorithmState::ClusterState state; + state.mutable_multi_ls_state()->mutable_mean()->CopyFrom(mean_proto); + state.mutable_multi_ls_state()->mutable_prec()->CopyFrom(prec_proto); + state.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom(prec_chol_proto); + return std::make_shared(state); +}; + +void NWPriorModel::update_hypers( + const std::vector &states) { + + auto &rng = bayesmix::Rng::Instance().get(); + + if (prior->has_fixed_values()) { + return; + } + + else if (prior->has_normal_mean_prior()) { + // Get hyperparameters + Eigen::VectorXd mu00 = + bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().mean()); + Eigen::MatrixXd sigma00 = + bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().var()); + double lambda0 = prior->normal_mean_prior().var_scaling(); + // Compute posterior hyperparameters + Eigen::MatrixXd sigma00inv = stan::math::inverse_spd(sigma00); + Eigen::MatrixXd prec = Eigen::MatrixXd::Zero(dim, dim); + Eigen::VectorXd num = Eigen::MatrixXd::Zero(dim, 1); + for (auto &st : states) { + Eigen::MatrixXd prec_i = bayesmix::to_eigen(st.multi_ls_state().prec()); + prec += prec_i; + num += prec_i * bayesmix::to_eigen(st.multi_ls_state().mean()); + } + prec = hypers.var_scaling * prec + sigma00inv; + num = hypers.var_scaling * num + sigma00inv * mu00; + Eigen::VectorXd mu_n = prec.llt().solve(num); + // Update hyperparameters with posterior sampling + hypers.mean = stan::math::multi_normal_prec_rng(mu_n, prec, rng); + } + + else if (prior->has_ngiw_prior()) { + // Get hyperparameters: + // for mu0 + Eigen::VectorXd mu00 = + bayesmix::to_eigen(prior->ngiw_prior().mean_prior().mean()); + Eigen::MatrixXd sigma00 = + bayesmix::to_eigen(prior->ngiw_prior().mean_prior().var()); + // for lambda0 + double alpha00 = prior->ngiw_prior().var_scaling_prior().shape(); + double beta00 = prior->ngiw_prior().var_scaling_prior().rate(); + // for tau0 + double nu00 = prior->ngiw_prior().scale_prior().deg_free(); + Eigen::MatrixXd tau00 = + bayesmix::to_eigen(prior->ngiw_prior().scale_prior().scale()); + // Compute posterior hyperparameters + Eigen::MatrixXd sigma00inv = stan::math::inverse_spd(sigma00); + Eigen::MatrixXd tau_n = Eigen::MatrixXd::Zero(dim, dim); + Eigen::VectorXd num = Eigen::MatrixXd::Zero(dim, 1); + double beta_n = 0.0; + for (auto &st : states) { + Eigen::VectorXd mean = bayesmix::to_eigen(st.multi_ls_state().mean()); + Eigen::MatrixXd prec = bayesmix::to_eigen(st.multi_ls_state().prec()); + tau_n += prec; + num += prec * mean; + beta_n += + (hypers.mean - mean).transpose() * prec * (hypers.mean - mean); + } + Eigen::MatrixXd prec_n = hypers.var_scaling * tau_n + sigma00inv; + tau_n += tau00; + num = hypers.var_scaling * num + sigma00inv * mu00; + beta_n = beta00 + 0.5 * beta_n; + Eigen::MatrixXd sig_n = stan::math::inverse_spd(prec_n); + Eigen::VectorXd mu_n = sig_n * num; + double alpha_n = alpha00 + 0.5 * states.size(); + double nu_n = nu00 + states.size() * hypers.deg_free; + // Update hyperparameters with posterior random Gibbs sampling + hypers.mean = stan::math::multi_normal_rng(mu_n, sig_n, rng); + hypers.var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); + hypers.scale = stan::math::inv_wishart_rng(nu_n, tau_n, rng); + hypers.scale_inv = stan::math::inverse_spd(hypers.scale); + hypers.scale_chol = Eigen::LLT(hypers.scale).matrixU(); + } + + else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void NWPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).nnw_state(); + hypers.mean = bayesmix::to_eigen(hyperscast.mean()); + hypers.var_scaling = hyperscast.var_scaling(); + hypers.deg_free = hyperscast.deg_free(); + hypers.scale = bayesmix::to_eigen(hyperscast.scale()); + hypers.scale_inv = stan::math::inverse_spd(hypers.scale); + hypers.scale_chol = Eigen::LLT(hypers.scale).matrixU(); +} + +std::shared_ptr +NWPriorModel::get_hypers_proto() const { + + // Translate to proto + bayesmix::Vector mean_proto; + bayesmix::Matrix scale_proto; + bayesmix::to_proto(hypers.mean, &mean_proto); + bayesmix::to_proto(hypers.scale, &scale_proto); + + // Make output state and return + auto out = std::make_shared(); + out->mutable_nnw_state()->mutable_mean()->CopyFrom(mean_proto); + out->mutable_nnw_state()->set_var_scaling(hypers.var_scaling); + out->mutable_nnw_state()->set_deg_free(hypers.deg_free); + out->mutable_nnw_state()->mutable_scale()->CopyFrom(scale_proto); + return out; +} + +void NWPriorModel::write_prec_to_state(const Eigen::MatrixXd &prec_, State::MultiLS *out) { + out->prec = prec_; + // Update prec utilities + out->prec_chol = Eigen::LLT(prec_).matrixU(); + Eigen::VectorXd diag = out->prec_chol.diagonal(); + out->prec_logdet = 2 * log(diag.array()).sum(); +} \ No newline at end of file diff --git a/src/hierarchies/priors/nw_prior_model.h b/src/hierarchies/priors/nw_prior_model.h new file mode 100644 index 000000000..3d4ba3e5b --- /dev/null +++ b/src/hierarchies/priors/nw_prior_model.h @@ -0,0 +1,46 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_NW_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_NW_PRIOR_MODEL_H_ + +// #include + +#include +#include +#include +#include + +// #include "algorithm_state.pb.h" +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +class NWPriorModel : public BasePriorModel { + public: + NWPriorModel() = default; + ~NWPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + std::shared_ptr sample( + bool use_post_hypers) override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + protected: + + std::shared_ptr get_hypers_proto() + const override; + + void initialize_hypers() override; + + void write_prec_to_state(const Eigen::MatrixXd &prec_, State::MultiLS *out); + + unsigned int dim; +}; + +#endif // BAYESMIX_HIERARCHIES_PRIORS_NW_PRIOR_MODEL_H_ From c98e32ff46689db2c4652af1f75f59e19d356507 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 28 Jan 2022 21:14:16 +0100 Subject: [PATCH 117/317] Add test for MultiNormLikelihood --- test/likelihoods.cc | 82 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/test/likelihoods.cc b/test/likelihoods.cc index d5af2a024..02eeeb9f4 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -7,6 +7,9 @@ #include "algorithm_state.pb.h" #include "ls_state.pb.h" #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" +#include "src/hierarchies/likelihoods/multi_norm_likelihood.h" + +#include "src/utils/proto_utils.h" #include "src/utils/rng.h" TEST(uni_norm_likelihood, set_get_state) { @@ -74,3 +77,82 @@ TEST(uni_norm_likelihood, eval_lpdf) { // Check if they coincides ASSERT_EQ(evals, evals_copy); } + +TEST(multi_norm_likelihood, set_get_state) { + // Instance + auto like = std::make_shared(); + + // Prepare buffers + bayesmix::MultiLSState state_; + bayesmix::AlgorithmState::ClusterState set_state_; + bayesmix::AlgorithmState::ClusterState got_state_; + + // Prepare state + Eigen::Vector2d mu = {5.5, 5.5}; //mu << 5.5, 5.5; + Eigen::Matrix2d prec = Eigen::Matrix2d::Identity(); + bayesmix::Vector mean_proto; + bayesmix::Matrix prec_proto; + bayesmix::to_proto(mu, &mean_proto); + bayesmix::to_proto(prec, &prec_proto); + set_state_.mutable_multi_ls_state()->mutable_mean()->CopyFrom(mean_proto); + set_state_.mutable_multi_ls_state()->mutable_prec()->CopyFrom(prec_proto); + set_state_.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom(prec_proto); + + // Set and get the state + like->set_state_from_proto(set_state_); + like->write_state_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(multi_norm_likelihood, add_remove_data) { + // Instance + auto like = std::make_shared(); + + // Add new datum to likelihood + Eigen::RowVectorXd datum(2); + datum << 5.5, 5.5; + like->add_datum(0, datum); + + // Check if cardinality is augmented + ASSERT_EQ(like->get_card(), 1); + + // Remove datum from likelihood + like->remove_datum(0, datum); + + // Check if cardinality is reduced + ASSERT_EQ(like->get_card(), 0); +} + +TEST(multi_norm_likelihood, eval_lpdf) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::AlgorithmState::ClusterState clust_state_; + Eigen::Vector2d mu = {5.5, 5.5}; //mu << 5.5, 5.5; + Eigen::Matrix2d prec = Eigen::Matrix2d::Identity(); + bayesmix::Vector mean_proto; + bayesmix::Matrix prec_proto; + bayesmix::to_proto(mu, &mean_proto); + bayesmix::to_proto(prec, &prec_proto); + clust_state_.mutable_multi_ls_state()->mutable_mean()->CopyFrom(mean_proto); + clust_state_.mutable_multi_ls_state()->mutable_prec()->CopyFrom(prec_proto); + clust_state_.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom(prec_proto); + like->set_state_from_proto(clust_state_); + + // Data matrix on which evaluate the likelihood + Eigen::MatrixXd data(3,2); + data.row(0) << 4.5, 4.5; + data.row(1) << 5.1, 5.1; + data.row(2) << 2.5, 2.5; + + // Compute lpdf on this grid of points + auto evals = like->lpdf_grid(data); + auto like_copy = like->clone(); + auto evals_copy = like_copy->lpdf_grid(data); + + // Check if they coincides + ASSERT_EQ(evals, evals_copy); +} \ No newline at end of file From 94684c7e612db99906226ccc88632c0cf960e6eb Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 28 Jan 2022 21:14:28 +0100 Subject: [PATCH 118/317] Add test for NWPriorModel --- test/prior_models.cc | 143 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/test/prior_models.cc b/test/prior_models.cc index dc392b9e0..31756f4bf 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -6,8 +6,11 @@ #include "algorithm_state.pb.h" #include "hierarchy_prior.pb.h" + #include "src/hierarchies/priors/nig_prior_model.h" #include "src/hierarchies/priors/nxig_prior_model.h" +#include "src/hierarchies/priors/nw_prior_model.h" +#include "src/utils/proto_utils.h" TEST(nig_prior_model, set_get_hypers) { // Instance @@ -204,6 +207,146 @@ TEST(nxig_prior_model, sample) { auto state1 = prior->sample(!use_post_hypers); auto state2 = prior->sample(!use_post_hypers); + // Check if they coincides + ASSERT_TRUE(state1->DebugString() != state2->DebugString()); +} + +TEST(nw_prior_model, set_get_hypers) { + // Instance + auto prior = std::make_shared(); + + // Prepare buffers + bayesmix::NWDistribution hypers_; + bayesmix::AlgorithmState::HierarchyHypers set_state_; + bayesmix::AlgorithmState::HierarchyHypers got_state_; + + bayesmix::Vector mean_proto; + bayesmix::Matrix scale_proto; + bayesmix::to_proto(Eigen::Vector2d({5.5, 5.5}), &mean_proto); + bayesmix::to_proto(Eigen::Matrix2d::Identity(), &scale_proto); + + // Prepare hypers + hypers_.mutable_mean()->CopyFrom(mean_proto); + hypers_.set_deg_free(4); + hypers_.set_var_scaling(0.1); + hypers_.mutable_scale()->CopyFrom(scale_proto); + set_state_.mutable_nnw_state()->CopyFrom(hypers_); + + // Set and get hypers + prior->set_hypers_from_proto(set_state_); + prior->write_hypers_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(nw_prior_model, fixed_values_prior) { + // Prepare buffers + bayesmix::NNWPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + std::vector> prior_models; + std::vector states; + + // Set fixed value prior + bayesmix::Vector mean_proto; + bayesmix::Matrix scale_proto; + bayesmix::to_proto(Eigen::Vector2d({5.5, 5.5}), &mean_proto); + bayesmix::to_proto(Eigen::Matrix2d::Identity(), &scale_proto); + prior.mutable_fixed_values()->mutable_mean()->CopyFrom(mean_proto); + prior.mutable_fixed_values()->set_var_scaling(0.1); + prior.mutable_fixed_values()->set_deg_free(4); + prior.mutable_fixed_values()->mutable_scale()->CopyFrom(scale_proto); + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Check equality before update + prior_models.push_back(prior_model); + for (size_t i = 1; i < 4; i++) { + prior_models.push_back(prior_model->clone()); + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnw_state().DebugString()); + } + + // Check equality after update + prior_models[0]->update_hypers(states); + prior_models[0]->write_hypers_to_proto(&prior_out); + for (size_t i = 1; i < 4; i++) { + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.nnw_state().DebugString()); + } +} + +TEST(nw_prior_model, normal_mean_prior) { + // Prepare buffers + bayesmix::NNWPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + + // Set Normal prior on the mean + Eigen::Vector2d mu00 = Eigen::Vector2d::Zero(); + Eigen::Matrix2d Sigma00 = Eigen::Matrix2d::Identity(); + bayesmix::Vector mu00_proto; + bayesmix::Matrix Sigma00_proto, scale_proto; + bayesmix::to_proto(mu00, &mu00_proto); + bayesmix::to_proto(Sigma00, &Sigma00_proto); + bayesmix::to_proto(Eigen::Matrix2d::Identity(), &scale_proto); + prior.mutable_normal_mean_prior()->mutable_mean_prior()->mutable_mean()->CopyFrom(mu00_proto); + prior.mutable_normal_mean_prior()->mutable_mean_prior()->mutable_var()->CopyFrom(Sigma00_proto); + prior.mutable_normal_mean_prior()->set_var_scaling(0.1); + prior.mutable_normal_mean_prior()->set_deg_free(4); + prior.mutable_normal_mean_prior()->mutable_scale()->CopyFrom(scale_proto); + + // Prepare some fictional states + std::vector states(4); + for (int i = 0; i < states.size(); i++) { + Eigen::Vector2d mean = (9.0 + i) * Eigen::Vector2d::Ones(); + bayesmix::Vector tmp; bayesmix::to_proto(mean, &tmp); + states[i].mutable_multi_ls_state()->mutable_mean()->CopyFrom(tmp); + states[i].mutable_multi_ls_state()->mutable_prec()->CopyFrom(scale_proto); + states[i].mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom(scale_proto); + } + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Update hypers in light of current states + prior_model->update_hypers(states); + prior_model->write_hypers_to_proto(&prior_out); + Eigen::Vector2d mean_out = bayesmix::to_eigen(prior_out.nnw_state().mean()); + + // Check + for (size_t i = 0; i < mu00.size(); i++) { + ASSERT_GT(mean_out(i), mu00(i)); + } +} + +TEST(nw_prior_model, sample) { + // Instance + auto prior = std::make_shared(); + bool use_post_hypers = true; + + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + bayesmix::Vector mean; + bayesmix::Matrix scale; + bayesmix::to_proto(Eigen::Vector2d({5.2,5.2}), &mean); + bayesmix::to_proto(Eigen::Matrix2d::Identity(), &scale); + hypers_proto.mutable_nnw_state()->mutable_mean()->CopyFrom(mean); + hypers_proto.mutable_nnw_state()->set_var_scaling(0.1); + hypers_proto.mutable_nnw_state()->set_deg_free(4); + hypers_proto.mutable_nnw_state()->mutable_scale()->CopyFrom(scale); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + auto state1 = prior->sample(!use_post_hypers); + auto state2 = prior->sample(!use_post_hypers); + // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); } \ No newline at end of file From 4ba3c636ad6b3931dd9d2bdd4a5f9dc943aea3af Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 29 Jan 2022 15:05:15 +0100 Subject: [PATCH 119/317] Improved API --- src/hierarchies/likelihoods/multi_norm_likelihood.cc | 8 ++------ src/hierarchies/likelihoods/multi_norm_likelihood.h | 5 +++++ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.cc b/src/hierarchies/likelihoods/multi_norm_likelihood.cc index 353e5df23..f187bed6e 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.cc @@ -12,12 +12,8 @@ double MultiNormLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const void MultiNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, bool add) { // Prepare buffer in case dim is not defined yet - if (!dim) { - dim = datum.size(); - data_sum = Eigen::VectorXd::Zero(dim); - data_sum_squares = Eigen::MatrixXd::Zero(dim,dim); - } - + if (!dim) + set_dim(datum.size()); // Updates if (add) { data_sum += datum.transpose(); diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index 4d8a5a3e8..f61696ae8 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -23,6 +23,11 @@ class MultiNormLikelihood bool update_card = true) override; void clear_summary_statistics() override; + void set_dim(unsigned int dim_) { + dim = dim_; + data_sum = Eigen::VectorXd::Zero(dim); + data_sum_squares = Eigen::MatrixXd::Zero(dim,dim); + }; unsigned int get_dim() const { return dim; }; Eigen::VectorXd get_data_sum() const { return data_sum; }; Eigen::MatrixXd get_data_sum_squares() const { return data_sum_squares; }; From 489384f5e51862bbcf32d76dfe32f1375e0fd9bb Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 29 Jan 2022 15:05:33 +0100 Subject: [PATCH 120/317] Add updater for NNW hierarchy --- src/hierarchies/updaters/nnw_updater.cc | 41 +++++++++++++++++++++++++ src/hierarchies/updaters/nnw_updater.h | 17 ++++++++++ 2 files changed, 58 insertions(+) create mode 100644 src/hierarchies/updaters/nnw_updater.cc create mode 100644 src/hierarchies/updaters/nnw_updater.h diff --git a/src/hierarchies/updaters/nnw_updater.cc b/src/hierarchies/updaters/nnw_updater.cc new file mode 100644 index 000000000..5c6807ab1 --- /dev/null +++ b/src/hierarchies/updaters/nnw_updater.cc @@ -0,0 +1,41 @@ +#include "nnw_updater.h" + +#include "src/hierarchies/likelihoods/states.h" +#include "src/hierarchies/priors/hyperparams.h" + +void NNWUpdater::compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Getting required quantities from likelihood and prior + int card = likecast.get_card(); + Eigen::VectorXd data_sum = likecast.get_data_sum(); + Eigen::MatrixXd data_sum_squares = likecast.get_data_sum_squares(); + auto hypers = priorcast.get_hypers(); + + // No update possible + if (card == 0) { + priorcast.set_posterior_hypers(hypers); + return; + } + + // Compute posterior hyperparameters + Hyperparams::NW post_params; + post_params.var_scaling = hypers.var_scaling + card; + post_params.deg_free = hypers.deg_free + card; + Eigen::VectorXd mubar = data_sum.array() / card; // sample mean + post_params.mean = (hypers.var_scaling * hypers.mean + card * mubar) / + (hypers.var_scaling + card); + // Compute tau_n + Eigen::MatrixXd tau_temp = + data_sum_squares - card * mubar * mubar.transpose(); + tau_temp += (card * hypers.var_scaling / (card + hypers.var_scaling)) * + (mubar - hypers.mean) * (mubar - hypers.mean).transpose(); + post_params.scale_inv = tau_temp + hypers.scale_inv; + post_params.scale = stan::math::inverse_spd(post_params.scale_inv); + post_params.scale_chol = Eigen::LLT(post_params.scale).matrixU(); + priorcast.set_posterior_hypers(post_params); + return; +}; diff --git a/src/hierarchies/updaters/nnw_updater.h b/src/hierarchies/updaters/nnw_updater.h new file mode 100644 index 000000000..18f6a824f --- /dev/null +++ b/src/hierarchies/updaters/nnw_updater.h @@ -0,0 +1,17 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_NNW_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_NNW_UPDATER_H_ + +#include "conjugate_updater.h" +#include "src/hierarchies/likelihoods/multi_norm_likelihood.h" +#include "src/hierarchies/priors/nw_prior_model.h" + +class NNWUpdater : public ConjugateUpdater { + public: + NNWUpdater() = default; + ~NNWUpdater() = default; + + void compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) override; +}; + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_NNW_UPDATER_H_ From 33d9bd9b1a799c0b955428978227ad084d1718b2 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 29 Jan 2022 15:05:48 +0100 Subject: [PATCH 121/317] Improved API --- src/hierarchies/priors/nw_prior_model.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/priors/nw_prior_model.h b/src/hierarchies/priors/nw_prior_model.h index 3d4ba3e5b..e2dc2482d 100644 --- a/src/hierarchies/priors/nw_prior_model.h +++ b/src/hierarchies/priors/nw_prior_model.h @@ -31,6 +31,10 @@ class NWPriorModel : public BasePriorModel get_hypers_proto() @@ -38,8 +42,6 @@ class NWPriorModel : public BasePriorModel Date: Sat, 29 Jan 2022 15:06:08 +0100 Subject: [PATCH 122/317] Add NNW hierarchy --- src/hierarchies/CMakeLists.txt | 2 +- src/hierarchies/load_hierarchies.h | 10 +- src/hierarchies/nnw_hierarchy.h | 69 +++++++++++++ src/hierarchies/updaters/CMakeLists.txt | 2 + src/includes.h | 2 +- test/hierarchies.cc | 130 ++++++++++++------------ 6 files changed, 143 insertions(+), 72 deletions(-) create mode 100644 src/hierarchies/nnw_hierarchy.h diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index 537299476..173c97a86 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -4,10 +4,10 @@ target_sources(bayesmix base_hierarchy.h nnig_hierarchy.h nnxig_hierarchy.h + nnw_hierarchy.h # conjugate_hierarchy.h # lin_reg_uni_hierarchy.h # lin_reg_uni_hierarchy.cc - # nnw_hierarchy.h # nnw_hierarchy.cc ) diff --git a/src/hierarchies/load_hierarchies.h b/src/hierarchies/load_hierarchies.h index e42a3af92..1a3006675 100644 --- a/src/hierarchies/load_hierarchies.h +++ b/src/hierarchies/load_hierarchies.h @@ -8,7 +8,7 @@ #include "hierarchy_id.pb.h" // #include "lin_reg_uni_hierarchy.h" #include "nnig_hierarchy.h" -// #include "nnw_hierarchy.h" +#include "nnw_hierarchy.h" #include "nnxig_hierarchy.h" #include "src/runtime/factory.h" @@ -29,16 +29,16 @@ __attribute__((constructor)) static void load_hierarchies() { Builder NNxIGbuilder = []() { return std::make_shared(); }; - // Builder NNWbuilder = []() { - // return std::make_shared(); - // }; + Builder NNWbuilder = []() { + return std::make_shared(); + }; // Builder LinRegUnibuilder = []() { // return std::make_shared(); // }; factory.add_builder(NNIGHierarchy().get_id(), NNIGbuilder); factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); - // factory.add_builder(NNWHierarchy().get_id(), NNWbuilder); + factory.add_builder(NNWHierarchy().get_id(), NNWbuilder); // factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); } diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h new file mode 100644 index 000000000..dab6fa00a --- /dev/null +++ b/src/hierarchies/nnw_hierarchy.h @@ -0,0 +1,69 @@ +#ifndef BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ + +// #include + +// #include +// #include +// #include + +// #include "algorithm_state.pb.h" +// #include "conjugate_hierarchy.h" +#include "hierarchy_id.pb.h" +#include "src/utils/distributions.h" +// #include "hierarchy_prior.pb.h" + +#include "base_hierarchy.h" +#include "likelihoods/multi_norm_likelihood.h" +#include "priors/nw_prior_model.h" +#include "updaters/nnw_updater.h" + +class NNWHierarchy : public BaseHierarchy { + public: + NNWHierarchy() = default; + ~NNWHierarchy() = default; + + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::NNW; + } + + void initialize_state() override { + // Initialize likelihood dimension to prior one + like->set_dim(prior->get_dim()); + // Get hypers and data dimension + auto hypers = prior->get_hypers(); + unsigned int dim = like->get_dim(); + // Initialize likelihood state + State::MultiLS state; + state.mean = hypers.mean; + prior->write_prec_to_state(hypers.var_scaling * Eigen::MatrixXd::Identity(dim, dim), &state); + like->set_state(state); + }; + + double marg_lpdf(const HyperParams ¶ms, + const Eigen::RowVectorXd &datum) const override { + HyperParams pred_params = get_predictive_t_parameters(params); + Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); + double logdet = 2 * log(diag.array()).sum(); + return bayesmix::multi_student_t_invscale_lpdf( + datum, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, + logdet); + } + + HyperParams get_predictive_t_parameters(const HyperParams ¶ms) const { + // Compute dof and scale of marginal distribution + unsigned int dim = like->get_dim(); + double nu_n = params.deg_free - dim + 1; + double coeff = (params.var_scaling + 1) / (params.var_scaling * nu_n); + Eigen::MatrixXd scale_chol_n = params.scale_chol / std::sqrt(coeff); + // Return predictive t parameters + HyperParams out; + out.mean = params.mean; + out.deg_free = nu_n; + out.scale_chol = scale_chol_n; + return out; + } +}; + +#endif // BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index 3efd34b34..9e1841b9e 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -6,4 +6,6 @@ target_sources(bayesmix nnig_updater.cc nnxig_updater.h nnxig_updater.cc + nnw_updater.h + nnw_updater.cc ) diff --git a/src/includes.h b/src/includes.h index ae41aeefe..69daa7296 100644 --- a/src/includes.h +++ b/src/includes.h @@ -12,7 +12,7 @@ // #include "hierarchies/lin_reg_uni_hierarchy.h" #include "hierarchies/load_hierarchies.h" #include "hierarchies/nnig_hierarchy.h" -// #include "hierarchies/nnw_hierarchy.h" +#include "hierarchies/nnw_hierarchy.h" #include "hierarchies/nnxig_hierarchy.h" #include "mixings/dirichlet_mixing.h" #include "mixings/load_mixings.h" diff --git a/test/hierarchies.cc b/test/hierarchies.cc index b23819186..359fb4b04 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -7,7 +7,7 @@ #include "ls_state.pb.h" // #include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" -// #include "src/hierarchies/nnw_hierarchy.h" +#include "src/hierarchies/nnw_hierarchy.h" #include "src/hierarchies/nnxig_hierarchy.h" #include "src/utils/proto_utils.h" #include "src/utils/rng.h" @@ -69,71 +69,71 @@ TEST(nnig_hierarchy, sample_given_data) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -// TEST(nnwhierarchy, draw) { -// auto hier = std::make_shared(); -// bayesmix::NNWPrior prior; -// Eigen::Vector2d mu0; -// mu0 << 5.5, 5.5; -// bayesmix::Vector mu0_proto; -// bayesmix::to_proto(mu0, &mu0_proto); -// double lambda0 = 0.2; -// double nu0 = 5.0; -// Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; -// bayesmix::Matrix tau0_proto; -// bayesmix::to_proto(tau0, &tau0_proto); -// *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; -// prior.mutable_fixed_values()->set_var_scaling(lambda0); -// prior.mutable_fixed_values()->set_deg_free(nu0); -// *prior.mutable_fixed_values()->mutable_scale() = tau0_proto; -// hier->get_mutable_prior()->CopyFrom(prior); -// hier->initialize(); - -// auto hier2 = hier->clone(); -// hier2->sample_prior(); - -// bayesmix::AlgorithmState out; -// bayesmix::AlgorithmState::ClusterState* clusval = -// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 -// = out.add_cluster_states(); hier->write_state_to_proto(clusval); -// hier2->write_state_to_proto(clusval2); - -// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -// } +TEST(nnw_hierarchy, draw) { + auto hier = std::make_shared(); + bayesmix::NNWPrior prior; + Eigen::Vector2d mu0; + mu0 << 5.5, 5.5; + bayesmix::Vector mu0_proto; + bayesmix::to_proto(mu0, &mu0_proto); + double lambda0 = 0.2; + double nu0 = 5.0; + Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; + bayesmix::Matrix tau0_proto; + bayesmix::to_proto(tau0, &tau0_proto); + *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; + prior.mutable_fixed_values()->set_var_scaling(lambda0); + prior.mutable_fixed_values()->set_deg_free(nu0); + *prior.mutable_fixed_values()->mutable_scale() = tau0_proto; + hier->get_mutable_prior()->CopyFrom(prior); + hier->initialize(); -// TEST(nnwhierarchy, sample_given_data) { -// auto hier = std::make_shared(); -// bayesmix::NNWPrior prior; -// Eigen::Vector2d mu0; -// mu0 << 5.5, 5.5; -// bayesmix::Vector mu0_proto; -// bayesmix::to_proto(mu0, &mu0_proto); -// double lambda0 = 0.2; -// double nu0 = 5.0; -// Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; -// bayesmix::Matrix tau0_proto; -// bayesmix::to_proto(tau0, &tau0_proto); -// *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; -// prior.mutable_fixed_values()->set_var_scaling(lambda0); -// prior.mutable_fixed_values()->set_deg_free(nu0); -// *prior.mutable_fixed_values()->mutable_scale() = tau0_proto; -// hier->get_mutable_prior()->CopyFrom(prior); -// hier->initialize(); - -// Eigen::RowVectorXd datum(2); -// datum << 4.5, 4.5; - -// auto hier2 = hier->clone(); -// hier2->add_datum(0, datum, false); -// hier2->sample_full_cond(); - -// bayesmix::AlgorithmState out; -// bayesmix::AlgorithmState::ClusterState* clusval = -// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 -// = out.add_cluster_states(); hier->write_state_to_proto(clusval); -// hier2->write_state_to_proto(clusval2); - -// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -// } + auto hier2 = hier->clone(); + hier2->sample_prior(); + + bayesmix::AlgorithmState out; + bayesmix::AlgorithmState::ClusterState* clusval = + out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 + = out.add_cluster_states(); hier->write_state_to_proto(clusval); + hier2->write_state_to_proto(clusval2); + + ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +} + +TEST(nnw_hierarchy, sample_given_data) { + auto hier = std::make_shared(); + bayesmix::NNWPrior prior; + Eigen::Vector2d mu0; + mu0 << 5.5, 5.5; + bayesmix::Vector mu0_proto; + bayesmix::to_proto(mu0, &mu0_proto); + double lambda0 = 0.2; + double nu0 = 5.0; + Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; + bayesmix::Matrix tau0_proto; + bayesmix::to_proto(tau0, &tau0_proto); + *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; + prior.mutable_fixed_values()->set_var_scaling(lambda0); + prior.mutable_fixed_values()->set_deg_free(nu0); + *prior.mutable_fixed_values()->mutable_scale() = tau0_proto; + hier->get_mutable_prior()->CopyFrom(prior); + hier->initialize(); + + Eigen::RowVectorXd datum(2); + datum << 4.5, 4.5; + + auto hier2 = hier->clone(); + hier2->add_datum(0, datum, false); + hier2->sample_full_cond(); + + bayesmix::AlgorithmState out; + bayesmix::AlgorithmState::ClusterState* clusval = + out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 + = out.add_cluster_states(); hier->write_state_to_proto(clusval); + hier2->write_state_to_proto(clusval2); + + ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +} // TEST(lin_reg_uni_hierarchy, state_read_write) { // Eigen::Vector2d beta; From 82b1df0d48f1f1b8cc3285b6882e50e38eed7e2f Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Wed, 16 Feb 2022 14:50:24 +0100 Subject: [PATCH 123/317] added files --- .../updaters/target_lpdf_unconstrained.cc | 4 ++ .../updaters/target_lpdf_unconstrained.h | 24 +++++++++ test_mh_updater.cpp | 54 +++++++++++++++++++ 3 files changed, 82 insertions(+) create mode 100644 src/hierarchies/updaters/target_lpdf_unconstrained.cc create mode 100644 src/hierarchies/updaters/target_lpdf_unconstrained.h create mode 100644 test_mh_updater.cpp diff --git a/src/hierarchies/updaters/target_lpdf_unconstrained.cc b/src/hierarchies/updaters/target_lpdf_unconstrained.cc new file mode 100644 index 000000000..fa8fef2ba --- /dev/null +++ b/src/hierarchies/updaters/target_lpdf_unconstrained.cc @@ -0,0 +1,4 @@ +#include "target_lpdf_unconstrained.h" + +// target_lpdf_unconstrained::target_lpdf_unconstrained( +// AbstractHierarchy *p) : parent(p) {} diff --git a/src/hierarchies/updaters/target_lpdf_unconstrained.h b/src/hierarchies/updaters/target_lpdf_unconstrained.h new file mode 100644 index 000000000..9adc8cc02 --- /dev/null +++ b/src/hierarchies/updaters/target_lpdf_unconstrained.h @@ -0,0 +1,24 @@ +#ifndef BAYESMIX_SRC_HIERARCHIES_UPDATERS_TARGET_LPDF_UNCONSTRAINED_H_ +#define BAYESMIX_SRC_HIERARCHIES_UPDATERS_TARGET_LPDF_UNCONSTRAINED_H_ + +#include "src/hierarchies/likelihoods/abstract_likelihood.h" +#include "src/hierarchies/priors/abstract_prior_model.h" + +class target_lpdf_unconstrained { + protected: + AbstractLikelihood* like; + AbstractPriorModel* prior; + + public: + target_lpdf_unconstrained(AbstractLikelihood* like, + AbstractPriorModel* prior) + : like(like), prior(prior) {} + + template + T operator()(const Eigen::Matrix& x) const { + return like->cluster_lpdf_from_unconstrained(x) + + prior->lpdf_from_unconstrained(x); + } +}; + +#endif diff --git a/test_mh_updater.cpp b/test_mh_updater.cpp new file mode 100644 index 000000000..c568e7c9b --- /dev/null +++ b/test_mh_updater.cpp @@ -0,0 +1,54 @@ +#include + +#include +#include + +#include "lib/argparse/argparse.h" +#include "src/includes.h" + +int main() { + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + hypers_proto.mutable_nnig_state()->set_mean(0.0); + hypers_proto.mutable_nnig_state()->set_var_scaling(0.1); + hypers_proto.mutable_nnig_state()->set_shape(4.0); + hypers_proto.mutable_nnig_state()->set_scale(3.0); + + bayesmix::NNIGPrior hier_prior; + hier_prior.mutable_fixed_values()->set_mean(0.0); + hier_prior.mutable_fixed_values()->set_var_scaling(0.1); + hier_prior.mutable_fixed_values()->set_shape(4.0); + hier_prior.mutable_fixed_values()->set_scale(3.0); + + auto prior = std::make_shared(); + prior->get_mutable_prior()->CopyFrom(hier_prior); + + // prior->set_hypers_from_proto(hypers_proto); + auto like = std::make_shared(); + auto updater = std::make_shared(0.001); + auto hier = std::make_shared(); + hier->set_likelihood(like); + hier->set_prior(prior); + hier->set_updater(updater); + std::cout << "here" << std::endl; + + hier->initialize(); + std::cout << "initializing" << std::endl; + + auto& rng = bayesmix::Rng::Instance().get(); + int ndata = 250; + Eigen::VectorXd data(ndata); + for (int i = 0; i < ndata; i++) { + data(i) = stan::math::normal_rng(5, 1.0, rng); + hier->add_datum(i, data.row(i)); + } + + int niter = 10000; + Eigen::MatrixXd chain(niter, 2); + for (int i = 0; i < niter; i++) { + hier->sample_full_cond(); + chain.row(i) = hier->get_state().get_unconstrained(); + } + + bayesmix::write_matrix_to_file(chain, "mcmc_chain_test.csv"); +} From 826d9678b5976a2c830d908d8aa18766216ba321 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Wed, 16 Feb 2022 14:55:29 +0100 Subject: [PATCH 124/317] modified gitignore --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index bea0b049f..571e85715 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,8 @@ sftp-config.json *.local.* # MacOS storage files .DS_Store +.dockerignore +.ipynb_checkpoints/ +docs/_build/ +resources/benchmarks/datasets +resources/2d From 151f32b5bbdd4134ef3d0f3aa688a20f9ceae833 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 11:46:07 +0100 Subject: [PATCH 125/317] Remove double definition --- src/algorithms/neal8_algorithm.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/algorithms/neal8_algorithm.h b/src/algorithms/neal8_algorithm.h index 96d9ca177..0b25ada33 100644 --- a/src/algorithms/neal8_algorithm.h +++ b/src/algorithms/neal8_algorithm.h @@ -43,8 +43,6 @@ class Neal8Algorithm : public Neal2Algorithm { void read_params_from_proto( const bayesmix::AlgorithmParams ¶ms) override; - bool requires_conjugate_hierarchy() const override { return false; } - protected: void initialize() override; From eeb61b270f54972f8e19a1ce73dafb47334072c4 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 11:48:15 +0100 Subject: [PATCH 126/317] Bug fix + documentation --- src/hierarchies/base_hierarchy.h | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 169f901ec..2ce883874 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -28,18 +28,16 @@ //! @tparam Derived Name of the implemented derived class //! @tparam Likelihood Class name of the likelihood model for the hierarchy //! @tparam PriorModel Class name of the prior model for the hierarchy -//! @tparam Updater Class name for the update algorithm used for posterior sampling template class BaseHierarchy : public AbstractHierarchy { protected: - //! Container for the likelihood of the hierarchy std::shared_ptr like = std::make_shared(); //! Container for the prior model of the hierarchy std::shared_ptr prior = std::make_shared(); - + //! Container for the update algorithm adopted std::shared_ptr updater; @@ -74,13 +72,11 @@ class BaseHierarchy : public AbstractHierarchy { updater = updater_; }; - virtual void set_default_updater() = 0; - std::shared_ptr get_likelihood() override { return like; } std::shared_ptr get_prior() override { return prior; } - + //! Returns an independent, data-less copy of this object std::shared_ptr clone() const override { // Create copy of the hierarchy @@ -91,7 +87,8 @@ class BaseHierarchy : public AbstractHierarchy { return out; }; - // NOT SURE THIS IS CORRECT, MAYBE OVERRIDE GET_LIKE_LPDF? OR THIS IS EVEN UNNECESSARY + // NOT SURE THIS IS CORRECT, MAYBE OVERRIDE GET_LIKE_LPDF? OR THIS IS EVEN + // UNNECESSARY double like_lpdf(const Eigen::RowVectorXd &datum) const override { return like->lpdf(datum); } @@ -255,7 +252,8 @@ class BaseHierarchy : public AbstractHierarchy { like->write_state_to_proto(out); }; - //! Writes current values of the hyperparameters to a Protobuf message by pointer + //! Writes current values of the hyperparameters to a Protobuf message by + //! pointer void write_hypers_to_proto(google::protobuf::Message *out) const override { prior->write_hypers_to_proto(out); }; @@ -339,7 +337,6 @@ class BaseHierarchy : public AbstractHierarchy { //! Returns the struct of the current posterior hyperparameters // Hyperparams get_posterior_hypers() const { return posterior_hypers; } - //! Raises an error if the prior pointer is not initialized // void check_prior_is_set() const { // if (prior == nullptr) { @@ -362,7 +359,6 @@ class BaseHierarchy : public AbstractHierarchy { // virtual std::shared_ptr // get_state_proto() const = 0; - //! Writes current value of hyperparameters to a Protobuf message and //! return a shared_ptr. //! New hierarchies have to first modify the field 'oneof val' in the @@ -407,7 +403,6 @@ class BaseHierarchy : public AbstractHierarchy { // const bayesmix::AlgorithmState::HierarchyHypers &>(state_); // } - // //! Container for prior hyperparameters values // std::shared_ptr hypers; From de02a925e2a08f16933e256c6161043992517532 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 11:50:07 +0100 Subject: [PATCH 127/317] Updater is no more template --- src/hierarchies/nnig_hierarchy.h | 4 +--- src/hierarchies/nnxig_hierarchy.h | 9 +++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index fa797c5ac..d6f13febc 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -29,9 +29,7 @@ class NNIGHierarchy return bayesmix::HierarchyId::NNIG; } - void set_default_updater() override { - updater = std::make_shared(); - } + void set_default_updater() { updater = std::make_shared(); } void initialize_state() override { // Get hypers diff --git a/src/hierarchies/nnxig_hierarchy.h b/src/hierarchies/nnxig_hierarchy.h index 9d5e7ff9f..a99a27909 100644 --- a/src/hierarchies/nnxig_hierarchy.h +++ b/src/hierarchies/nnxig_hierarchy.h @@ -17,8 +17,8 @@ #include "priors/nxig_prior_model.h" #include "updaters/nnxig_updater.h" -class NNxIGHierarchy : public BaseHierarchy { +class NNxIGHierarchy + : public BaseHierarchy { public: NNxIGHierarchy() = default; ~NNxIGHierarchy() = default; @@ -27,6 +27,8 @@ class NNxIGHierarchy : public BaseHierarchy(); } + void initialize_state() override { // Get hypers auto hypers = prior->get_hypers(); @@ -36,7 +38,6 @@ class NNxIGHierarchy : public BaseHierarchyset_state(state); }; - }; -#endif // BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ +#endif // BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ From a2f4352eec3936ea05a66bda1b2898f91f388a1d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 11:52:11 +0100 Subject: [PATCH 128/317] Promoted state for AD and split in files --- .../likelihoods/{ => states/.old}/states.h | 0 src/hierarchies/likelihoods/states/includes.h | 7 ++ .../likelihoods/states/multi_ls_state.h | 93 +++++++++++++++++++ .../likelihoods/states/uni_ls_state.h | 72 ++++++++++++++ 4 files changed, 172 insertions(+) rename src/hierarchies/likelihoods/{ => states/.old}/states.h (100%) create mode 100644 src/hierarchies/likelihoods/states/includes.h create mode 100644 src/hierarchies/likelihoods/states/multi_ls_state.h create mode 100644 src/hierarchies/likelihoods/states/uni_ls_state.h diff --git a/src/hierarchies/likelihoods/states.h b/src/hierarchies/likelihoods/states/.old/states.h similarity index 100% rename from src/hierarchies/likelihoods/states.h rename to src/hierarchies/likelihoods/states/.old/states.h diff --git a/src/hierarchies/likelihoods/states/includes.h b/src/hierarchies/likelihoods/states/includes.h new file mode 100644 index 000000000..f811dd565 --- /dev/null +++ b/src/hierarchies/likelihoods/states/includes.h @@ -0,0 +1,7 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ + +#include "multi_ls_state.h" +#include "uni_ls_state.h" + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h new file mode 100644 index 000000000..ae35d7eb8 --- /dev/null +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -0,0 +1,93 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_MULTI_LS_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_MULTI_LS_STATE_H_ + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "src/utils/proto_utils.h" + +namespace State { + +template +Eigen::Matrix multi_ls_to_unconstrained( + Eigen::Matrix mean_in, + Eigen::Matrix prec_in) { + Eigen::Matrix prec_out = + stan::math::cov_matrix_free(prec_in); + Eigen::Matrix out(mean_in.size() + prec_out.size()); + out << mean_in, prec_out; + return out; +} + +template +std::tuple, + Eigen::Matrix> +multi_ls_to_constrained(Eigen::Matrix in) { + double dim_ = 0.5 * (std::sqrt(8 * in.size() + 9) - 3); + double dimf; + assert(modf(dim_, &dimf) == 0.0); + int dim = int(dimf); + Eigen::Matrix mean(dim); + mean << in.head(dim); + Eigen::Matrix prec(dim, dim); + prec = stan::math::cov_matrix_constrain(in.tail(in.size() - dim), dim); + return std::make_tuple(mean, prec); +} + +template +T multi_ls_log_det_jac( + Eigen::Matrix prec_constrained) { + T out = 0; + stan::math::positive_constrain(stan::math::cov_matrix_free(prec_constrained), + out); + return out; +} + +class MultiLS { + public: + Eigen::VectorXd mean; + Eigen::MatrixXd prec, prec_chol; + double prec_logdet; + + Eigen::VectorXd get_unconstrained() { + return multi_ls_to_unconstrained(mean, prec); + } + + void set_from_unconstrained(Eigen::VectorXd in) { + std::tie(mean, prec) = multi_ls_to_constrained(in); + set_from_constrained(mean, prec); + } + + void set_from_constrained(Eigen::VectorXd mean_, Eigen::MatrixXd prec_) { + mean = mean_; + prec = prec_; + prec_chol = Eigen::LLT(prec).matrixL(); + Eigen::VectorXd diag = prec_chol.diagonal(); + prec_logdet = 2 * log(diag.array()).sum(); + } + + void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { + mean = to_eigen(state_.multi_ls_state().mean()); + prec = to_eigen(state_.multi_ls_state().prec()); + prec_chol = to_eigen(state_.multi_ls_state().prec_chol()); + Eigen::VectorXd diag = prec_chol.diagonal(); + prec_logdet = 2 * log(diag.array()).sum(); + } + + bayesmix::AlgorithmState::ClusterState get_as_proto() { + bayesmix::AlgorithmState::ClusterState state; + bayesmix::to_proto(mean, state.mutable_multi_ls_state()->mutable_mean()); + bayesmix::to_proto(prec, state.mutable_multi_ls_state()->mutable_prec()); + bayesmix::to_proto(prec_chol, + state.mutable_multi_ls_state()->mutable_prec_chol()); + return state; + } + + double log_det_jac() { return multi_ls_log_det_jac(prec); } +}; + +} // namespace State + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_MULTI_LS_STATE_H_ diff --git a/src/hierarchies/likelihoods/states/uni_ls_state.h b/src/hierarchies/likelihoods/states/uni_ls_state.h new file mode 100644 index 000000000..c7b242ba5 --- /dev/null +++ b/src/hierarchies/likelihoods/states/uni_ls_state.h @@ -0,0 +1,72 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LS_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LS_STATE_H_ + +#include +#include + +#include "algorithm_state.pb.h" +#include "src/utils/proto_utils.h" + +namespace State { + +template +Eigen::Matrix uni_ls_to_constrained( + Eigen::Matrix in) { + Eigen::Matrix out(2); + out << in(0), stan::math::exp(in(1)); + return out; +} + +template +Eigen::Matrix uni_ls_to_unconstrained( + Eigen::Matrix in) { + Eigen::Matrix out(2); + out << in(0), stan::math::log(in(1)); + return out; +} + +template +T uni_ls_log_det_jac(Eigen::Matrix constrained) { + T out = 0; + stan::math::positive_constrain(stan::math::log(constrained(1)), out); + return out; +} + +class UniLS { + public: + double mean, var; + + Eigen::VectorXd get_unconstrained() { + Eigen::VectorXd temp(2); + temp << mean, var; + return uni_ls_to_unconstrained(temp); + } + + void set_from_unconstrained(Eigen::VectorXd in) { + Eigen::VectorXd temp = uni_ls_to_constrained(in); + mean = temp(0); + var = temp(1); + } + + void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { + mean = state_.uni_ls_state().mean(); + var = state_.uni_ls_state().var(); + } + + bayesmix::AlgorithmState::ClusterState get_as_proto() { + bayesmix::AlgorithmState::ClusterState state; + state.mutable_uni_ls_state()->set_mean(mean); + state.mutable_uni_ls_state()->set_var(var); + return state; + } + + double log_det_jac() { + Eigen::VectorXd temp(2); + temp << mean, var; + return uni_ls_log_det_jac(temp); + } +}; + +} // namespace State + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LS_STATE_H_ From ca0e07bde6165b7abdf8bd9ca51d39d7bf5db0c2 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 11:52:50 +0100 Subject: [PATCH 129/317] Update includes --- src/hierarchies/priors/abstract_prior_model.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index bf7354c36..aec5bc15e 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -8,7 +8,7 @@ #include #include "algorithm_state.pb.h" -#include "src/hierarchies/likelihoods/states.h" +#include "src/hierarchies/likelihoods/states/includes.h" #include "src/utils/rng.h" class AbstractPriorModel { From 7f7a73380a8b6001953b735821b230ae70464e07 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 11:54:38 +0100 Subject: [PATCH 130/317] Updater no more tparam --- src/hierarchies/nnw_hierarchy.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index dab6fa00a..12d1c543e 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -18,8 +18,8 @@ #include "priors/nw_prior_model.h" #include "updaters/nnw_updater.h" -class NNWHierarchy : public BaseHierarchy { +class NNWHierarchy + : public BaseHierarchy { public: NNWHierarchy() = default; ~NNWHierarchy() = default; @@ -28,6 +28,8 @@ class NNWHierarchy : public BaseHierarchy(); } + void initialize_state() override { // Initialize likelihood dimension to prior one like->set_dim(prior->get_dim()); @@ -37,7 +39,8 @@ class NNWHierarchy : public BaseHierarchywrite_prec_to_state(hypers.var_scaling * Eigen::MatrixXd::Identity(dim, dim), &state); + prior->write_prec_to_state( + hypers.var_scaling * Eigen::MatrixXd::Identity(dim, dim), &state); like->set_state(state); }; @@ -66,4 +69,4 @@ class NNWHierarchy : public BaseHierarchy Date: Thu, 17 Feb 2022 11:55:36 +0100 Subject: [PATCH 131/317] Update target sources --- src/hierarchies/likelihoods/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index 1fe527150..4a642dd9d 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -1,9 +1,10 @@ target_sources(bayesmix PUBLIC abstract_likelihood.h base_likelihood.h - states.h uni_norm_likelihood.h uni_norm_likelihood.cc multi_norm_likelihood.h multi_norm_likelihood.cc ) + +add_subdirectory(states) From ada4b285ed1e4a99ea990133c02dbffb71163711 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 11:56:14 +0100 Subject: [PATCH 132/317] Add exeption handling test --- test/hierarchies.cc | 45 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/test/hierarchies.cc b/test/hierarchies.cc index 006abae5b..97467f141 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -92,9 +92,9 @@ TEST(nnw_hierarchy, draw) { hier2->sample_prior(); bayesmix::AlgorithmState out; - bayesmix::AlgorithmState::ClusterState* clusval = - out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 - = out.add_cluster_states(); hier->write_state_to_proto(clusval); + bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); + bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); + hier->write_state_to_proto(clusval); hier2->write_state_to_proto(clusval2); ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); @@ -127,14 +127,41 @@ TEST(nnw_hierarchy, sample_given_data) { hier2->sample_full_cond(); bayesmix::AlgorithmState out; - bayesmix::AlgorithmState::ClusterState* clusval = - out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 - = out.add_cluster_states(); hier->write_state_to_proto(clusval); + bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); + bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); + hier->write_state_to_proto(clusval); hier2->write_state_to_proto(clusval2); ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } +TEST(nnw_hierarchy, no_unconstrained_lpdf) { + // Initialize hierarchy + auto hier = std::make_shared(); + bayesmix::NNWPrior prior; + Eigen::Vector2d mu0; + mu0 << 5.5, 5.5; + bayesmix::Vector mu0_proto; + bayesmix::to_proto(mu0, &mu0_proto); + double lambda0 = 0.2; + double nu0 = 5.0; + Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; + bayesmix::Matrix tau0_proto; + bayesmix::to_proto(tau0, &tau0_proto); + *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; + prior.mutable_fixed_values()->set_var_scaling(lambda0); + prior.mutable_fixed_values()->set_deg_free(nu0); + *prior.mutable_fixed_values()->mutable_scale() = tau0_proto; + hier->get_mutable_prior()->CopyFrom(prior); + hier->initialize(); + + // Check exeption handling in case unconstrained lpdfs are not implemented + auto state_uc = hier->get_state().get_unconstrained(); + EXPECT_ANY_THROW( + hier->get_likelihood()->cluster_lpdf_from_unconstrained(state_uc)); + EXPECT_ANY_THROW(hier->get_prior()->lpdf_from_unconstrained(state_uc)); +} + // TEST(lin_reg_uni_hierarchy, state_read_write) { // Eigen::Vector2d beta; // beta << 2, -1; @@ -237,9 +264,9 @@ TEST(nnxig_hierarchy, draw) { hier2->sample_prior(); bayesmix::AlgorithmState out; - bayesmix::AlgorithmState::ClusterState* clusval = - out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 - = out.add_cluster_states(); hier->write_state_to_proto(clusval); + bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); + bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); + hier->write_state_to_proto(clusval); hier2->write_state_to_proto(clusval2); ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); From 56dec58ef1db273c58a0dafff5e1de982b6c1f53 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 11:56:48 +0100 Subject: [PATCH 133/317] Cleaned code --- src/hierarchies/likelihoods/abstract_likelihood.h | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index c2f5b58c6..ca7ef4403 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -4,15 +4,10 @@ #include #include +#include #include -// #include -// #include -// #include - #include "algorithm_state.pb.h" -// #include "hierarchy_id.pb.h" -// #include "src/utils/rng.h" class AbstractLikelihood { public: From 2ae45d07112c585d7295847183a88b869d5aa7c1 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 11:58:45 +0100 Subject: [PATCH 134/317] SFINAE Exeption handling for unconstrained lpdf --- src/hierarchies/likelihoods/base_likelihood.h | 53 +++++++++++++++---- src/hierarchies/priors/base_prior_model.h | 43 +++++++++++++-- 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 2c7b7993b..c2072779a 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -4,14 +4,35 @@ #include #include -#include -// #include #include -// #include +#include +#include #include "abstract_likelihood.h" #include "algorithm_state.pb.h" +namespace internal { + +template +auto cluster_lpdf_from_unconstrained( + const Like &like, Eigen::Matrix unconstrained_params, + int) + -> decltype(like.template cluster_lpdf_from_unconstrained( + unconstrained_params)) { + return like.template cluster_lpdf_from_unconstrained( + unconstrained_params); +} + +template +auto cluster_lpdf_from_unconstrained( + const Like &like, Eigen::Matrix unconstrained_params, + double) -> T { + throw(std::runtime_error( + "cluster_lpdf_from_unconstrained() not yet implemented")); +} + +} // namespace internal + template class BaseLikelihood : public AbstractLikelihood { public: @@ -26,19 +47,33 @@ class BaseLikelihood : public AbstractLikelihood { } // The unconstrained parameters are mean and log(var) + + // double cluster_lpdf_from_unconstrained( + // Eigen::VectorXd unconstrained_params) const override { + // return static_cast(*this) + // .template cluster_lpdf_from_unconstrained( + // unconstrained_params); + // } + + // stan::math::var cluster_lpdf_from_unconstrained( + // Eigen::Matrix + // unconstrained_params) const override { + // return static_cast(*this) + // .template cluster_lpdf_from_unconstrained( + // unconstrained_params); + // } + double cluster_lpdf_from_unconstrained( Eigen::VectorXd unconstrained_params) const override { - return static_cast(*this) - .template cluster_lpdf_from_unconstrained( - unconstrained_params); + return internal::cluster_lpdf_from_unconstrained( + static_cast(*this), unconstrained_params, 0); } stan::math::var cluster_lpdf_from_unconstrained( Eigen::Matrix unconstrained_params) const override { - return static_cast(*this) - .template cluster_lpdf_from_unconstrained( - unconstrained_params); + return internal::cluster_lpdf_from_unconstrained( + static_cast(*this), unconstrained_params, 0); } virtual Eigen::VectorXd lpdf_grid(const Eigen::MatrixXd &data, diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 506f20796..0608bc40a 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -14,6 +14,26 @@ #include "hierarchy_id.pb.h" #include "src/utils/rng.h" +namespace internal { + +template +auto lpdf_from_unconstrained( + const Prior &prior, + Eigen::Matrix unconstrained_params, int) + -> decltype(prior.template lpdf_from_unconstrained( + unconstrained_params)) { + return prior.template lpdf_from_unconstrained(unconstrained_params); +} + +template +auto lpdf_from_unconstrained( + const Prior &prior, + Eigen::Matrix unconstrained_params, double) -> T { + throw(std::runtime_error("lpdf_from_unconstrained() not yet implemented")); +} + +} // namespace internal + template class BasePriorModel : public AbstractPriorModel { public: @@ -23,18 +43,31 @@ class BasePriorModel : public AbstractPriorModel { double lpdf_from_unconstrained( Eigen::VectorXd unconstrained_params) const override { - return static_cast(*this) - .template lpdf_from_unconstrained(unconstrained_params); + return internal::lpdf_from_unconstrained( + static_cast(*this), unconstrained_params, 0); } stan::math::var lpdf_from_unconstrained( Eigen::Matrix unconstrained_params) const override { - return static_cast(*this) - .template lpdf_from_unconstrained( - unconstrained_params); + return internal::lpdf_from_unconstrained( + static_cast(*this), unconstrained_params, 0); } + // double lpdf_from_unconstrained( + // Eigen::VectorXd unconstrained_params) const override { + // return static_cast(*this) + // .template lpdf_from_unconstrained(unconstrained_params); + // } + + // stan::math::var lpdf_from_unconstrained( + // Eigen::Matrix + // unconstrained_params) const override { + // return static_cast(*this) + // .template lpdf_from_unconstrained( + // unconstrained_params); + // } + virtual std::shared_ptr clone() const override; virtual google::protobuf::Message *get_mutable_prior() override; From 3f0a8dbf031806da96055c1d4bc7ed1ad92fce09 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 11:59:11 +0100 Subject: [PATCH 135/317] Update includes --- test/distributions.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributions.cc b/test/distributions.cc index 1e1468b15..612f3707f 100644 --- a/test/distributions.cc +++ b/test/distributions.cc @@ -6,7 +6,7 @@ #include #include -#include "src/hierarchies/likelihoods/states.h" +#include "src/hierarchies/likelihoods/states/includes.h" #include "src/utils/rng.h" TEST(mix_dist, 1) { From 11af1db4a49a32d2ca97ead5a8ff44f22232721a Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 12:00:24 +0100 Subject: [PATCH 136/317] Update includes --- src/hierarchies/priors/nw_prior_model.h | 6 +++--- src/hierarchies/priors/nxig_prior_model.h | 5 +++-- src/hierarchies/updaters/nnig_updater.cc | 4 ++-- src/hierarchies/updaters/nnw_updater.cc | 7 ++++--- src/hierarchies/updaters/nnxig_updater.cc | 7 ++----- 5 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/hierarchies/priors/nw_prior_model.h b/src/hierarchies/priors/nw_prior_model.h index e2dc2482d..9fd349bc1 100644 --- a/src/hierarchies/priors/nw_prior_model.h +++ b/src/hierarchies/priors/nw_prior_model.h @@ -3,9 +3,10 @@ // #include -#include +// #include #include #include +#include #include // #include "algorithm_state.pb.h" @@ -15,7 +16,7 @@ #include "src/utils/rng.h" class NWPriorModel : public BasePriorModel { + bayesmix::NNWPrior> { public: NWPriorModel() = default; ~NWPriorModel() = default; @@ -36,7 +37,6 @@ class NWPriorModel : public BasePriorModel get_hypers_proto() const override; diff --git a/src/hierarchies/priors/nxig_prior_model.h b/src/hierarchies/priors/nxig_prior_model.h index a7282187a..85ade8612 100644 --- a/src/hierarchies/priors/nxig_prior_model.h +++ b/src/hierarchies/priors/nxig_prior_model.h @@ -3,9 +3,10 @@ // #include -#include +// #include #include #include +#include #include // #include "algorithm_state.pb.h" @@ -15,7 +16,7 @@ #include "src/utils/rng.h" class NxIGPriorModel : public BasePriorModel { + bayesmix::NNxIGPrior> { public: NxIGPriorModel() = default; ~NxIGPriorModel() = default; diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc index 318ba0908..9bd709877 100644 --- a/src/hierarchies/updaters/nnig_updater.cc +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -1,6 +1,6 @@ #include "nnig_updater.h" -#include "src/hierarchies/likelihoods/states.h" +#include "src/hierarchies/likelihoods/states/includes.h" #include "src/hierarchies/priors/hyperparams.h" void NNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, @@ -8,7 +8,7 @@ void NNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, // Likelihood and Prior downcast auto& likecast = downcast_likelihood(like); auto& priorcast = downcast_prior(prior); - + // Getting required quantities from likelihood and prior int card = likecast.get_card(); double data_sum = likecast.get_data_sum(); diff --git a/src/hierarchies/updaters/nnw_updater.cc b/src/hierarchies/updaters/nnw_updater.cc index 5c6807ab1..ad43fe6fc 100644 --- a/src/hierarchies/updaters/nnw_updater.cc +++ b/src/hierarchies/updaters/nnw_updater.cc @@ -1,6 +1,6 @@ #include "nnw_updater.h" -#include "src/hierarchies/likelihoods/states.h" +#include "src/hierarchies/likelihoods/states/includes.h" #include "src/hierarchies/priors/hyperparams.h" void NNWUpdater::compute_posterior_hypers(AbstractLikelihood& like, @@ -8,7 +8,7 @@ void NNWUpdater::compute_posterior_hypers(AbstractLikelihood& like, // Likelihood and Prior downcast auto& likecast = downcast_likelihood(like); auto& priorcast = downcast_prior(prior); - + // Getting required quantities from likelihood and prior int card = likecast.get_card(); Eigen::VectorXd data_sum = likecast.get_data_sum(); @@ -35,7 +35,8 @@ void NNWUpdater::compute_posterior_hypers(AbstractLikelihood& like, (mubar - hypers.mean) * (mubar - hypers.mean).transpose(); post_params.scale_inv = tau_temp + hypers.scale_inv; post_params.scale = stan::math::inverse_spd(post_params.scale_inv); - post_params.scale_chol = Eigen::LLT(post_params.scale).matrixU(); + post_params.scale_chol = + Eigen::LLT(post_params.scale).matrixU(); priorcast.set_posterior_hypers(post_params); return; }; diff --git a/src/hierarchies/updaters/nnxig_updater.cc b/src/hierarchies/updaters/nnxig_updater.cc index 3c7064081..c94c8ba3a 100644 --- a/src/hierarchies/updaters/nnxig_updater.cc +++ b/src/hierarchies/updaters/nnxig_updater.cc @@ -1,6 +1,6 @@ #include "nnxig_updater.h" -#include "src/hierarchies/likelihoods/states.h" +#include "src/hierarchies/likelihoods/states/includes.h" #include "src/hierarchies/priors/hyperparams.h" void NNxIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, @@ -27,12 +27,9 @@ void NNxIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, card * state.mean * state.mean; post_params.mean = (hypers.var * data_sum + state.var * hypers.mean) / (card * hypers.var + state.var); - post_params.var = - (state.var * hypers.var) / (card * hypers.var + state.var); + post_params.var = (state.var * hypers.var) / (card * hypers.var + state.var); post_params.shape = hypers.shape + 0.5 * card; post_params.scale = hypers.scale + 0.5 * var_y; priorcast.set_posterior_hypers(post_params); return; }; - - From c8c93139e5442aa3e56a51553c25d50eb51db48b Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 12:01:09 +0100 Subject: [PATCH 137/317] Improved code --- src/hierarchies/likelihoods/multi_norm_likelihood.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index f61696ae8..4bd936f3f 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -3,14 +3,14 @@ #include -#include #include #include +#include #include #include "algorithm_state.pb.h" #include "base_likelihood.h" -#include "states.h" +#include "states/includes.h" class MultiNormLikelihood : public BaseLikelihood { @@ -26,7 +26,7 @@ class MultiNormLikelihood void set_dim(unsigned int dim_) { dim = dim_; data_sum = Eigen::VectorXd::Zero(dim); - data_sum_squares = Eigen::MatrixXd::Zero(dim,dim); + data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); }; unsigned int get_dim() const { return dim; }; Eigen::VectorXd get_data_sum() const { return data_sum; }; From 2d97dcaa9ceb709c849cc97dbc57b712af0dc5e0 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 12:01:34 +0100 Subject: [PATCH 138/317] Improved code --- .../likelihoods/multi_norm_likelihood.cc | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.cc b/src/hierarchies/likelihoods/multi_norm_likelihood.cc index f187bed6e..8c12942aa 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.cc @@ -4,35 +4,36 @@ #include "src/utils/eigen_utils.h" #include "src/utils/proto_utils.h" -double MultiNormLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { +double MultiNormLikelihood::compute_lpdf( + const Eigen::RowVectorXd &datum) const { return bayesmix::multi_normal_prec_lpdf(datum, state.mean, state.prec_chol, state.prec_logdet); } -void MultiNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, bool add) { - - // Prepare buffer in case dim is not defined yet - if (!dim) - set_dim(datum.size()); - // Updates - if (add) { - data_sum += datum.transpose(); - data_sum_squares += datum.transpose() * datum; - } else { - data_sum -= datum.transpose(); - data_sum_squares -= datum.transpose() * datum; - } +void MultiNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, + bool add) { + // Check if dim is not defined yet (usually not happens if hierarchy is + // initialized) + if (!dim) set_dim(datum.size()); + // Updates + if (add) { + data_sum += datum.transpose(); + data_sum_squares += datum.transpose() * datum; + } else { + data_sum -= datum.transpose(); + data_sum_squares -= datum.transpose() * datum; + } } void MultiNormLikelihood::set_state_from_proto( const google::protobuf::Message &state_, bool update_card) { - auto &statecast = downcast_state(state_); - state.mean = to_eigen(statecast.multi_ls_state().mean()); - state.prec = to_eigen(statecast.multi_ls_state().prec()); - state.prec_chol = to_eigen(statecast.multi_ls_state().prec_chol()); - Eigen::VectorXd diag = state.prec_chol.diagonal(); - state.prec_logdet = 2 * log(diag.array()).sum(); - if (update_card) set_card(statecast.cardinality()); + auto &statecast = downcast_state(state_); + state.mean = to_eigen(statecast.multi_ls_state().mean()); + state.prec = to_eigen(statecast.multi_ls_state().prec()); + state.prec_chol = to_eigen(statecast.multi_ls_state().prec_chol()); + Eigen::VectorXd diag = state.prec_chol.diagonal(); + state.prec_logdet = 2 * log(diag.array()).sum(); + if (update_card) set_card(statecast.cardinality()); } std::shared_ptr @@ -47,6 +48,6 @@ MultiNormLikelihood::get_state_proto() const { } void MultiNormLikelihood::clear_summary_statistics() { - data_sum = Eigen::VectorXd::Zero(dim); - data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); + data_sum = Eigen::VectorXd::Zero(dim); + data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); } From 958f04adcf6be7d8bdd45536eeb0091a6b3c2e80 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 12:01:57 +0100 Subject: [PATCH 139/317] Update includes --- src/hierarchies/likelihoods/uni_norm_likelihood.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 9947eb53e..1579a6191 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -9,7 +9,7 @@ #include "algorithm_state.pb.h" #include "base_likelihood.h" -#include "states.h" +#include "states/includes.h" class UniNormLikelihood : public BaseLikelihood { From d4d13628b06dcff5ad0f30d6e86eb2b27223b4a8 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Thu, 17 Feb 2022 12:02:36 +0100 Subject: [PATCH 140/317] Add target sources --- src/hierarchies/likelihoods/states/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 src/hierarchies/likelihoods/states/CMakeLists.txt diff --git a/src/hierarchies/likelihoods/states/CMakeLists.txt b/src/hierarchies/likelihoods/states/CMakeLists.txt new file mode 100644 index 000000000..593bef3c5 --- /dev/null +++ b/src/hierarchies/likelihoods/states/CMakeLists.txt @@ -0,0 +1,5 @@ +target_sources(bayesmix PUBLIC + includes.h + uni_ls_state.h + multi_ls_state.h +) From d87d243cd72010de16b66f233eddd32c65f7f111 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 09:17:10 +0100 Subject: [PATCH 141/317] Add uni_lin_reg_state (ONGOING) --- .../likelihoods/states/CMakeLists.txt | 3 +- src/hierarchies/likelihoods/states/includes.h | 1 + .../likelihoods/states/uni_lin_reg_state.h | 79 +++++++++++++++++++ 3 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 src/hierarchies/likelihoods/states/uni_lin_reg_state.h diff --git a/src/hierarchies/likelihoods/states/CMakeLists.txt b/src/hierarchies/likelihoods/states/CMakeLists.txt index 593bef3c5..fa3d87c32 100644 --- a/src/hierarchies/likelihoods/states/CMakeLists.txt +++ b/src/hierarchies/likelihoods/states/CMakeLists.txt @@ -1,5 +1,6 @@ target_sources(bayesmix PUBLIC - includes.h uni_ls_state.h multi_ls_state.h + uni_lin_reg_state.h + includes.h ) diff --git a/src/hierarchies/likelihoods/states/includes.h b/src/hierarchies/likelihoods/states/includes.h index f811dd565..d0c921f6f 100644 --- a/src/hierarchies/likelihoods/states/includes.h +++ b/src/hierarchies/likelihoods/states/includes.h @@ -2,6 +2,7 @@ #define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ #include "multi_ls_state.h" +#include "uni_lin_reg_state.h" #include "uni_ls_state.h" #endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ diff --git a/src/hierarchies/likelihoods/states/uni_lin_reg_state.h b/src/hierarchies/likelihoods/states/uni_lin_reg_state.h new file mode 100644 index 000000000..b7e889afa --- /dev/null +++ b/src/hierarchies/likelihoods/states/uni_lin_reg_state.h @@ -0,0 +1,79 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_STATE_H_ + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "src/utils/proto_utils.h" + +// TODO: CHECK VECTOR ASSIGNMENTS AND POSITIONING! +namespace State { + +template +Eigen::Matrix uni_lin_reg_to_constrained( + Eigen::Matrix in) { + int N = in.size(); + Eigen::Matrix out(N); + out << in.head(N - 1), stan::math::exp(in(N - 1)); + return out; +} + +template +Eigen::Matrix uni_lin_reg_to_unconstrained( + Eigen::Matrix in) { + int N = in.size(); + Eigen::Matrix out(N); + out << in.head(N - 1), stan::math::log(in(N - 1)); + return out; +} + +template +T uni_lin_reg_log_det_jac(Eigen::Matrix constrained) { + T out = 0; + int N = constrained.size(); + stan::math::positive_constrain(stan::math::log(constrained(N - 1)), out); + return out; +} + +class UniLinReg { + public: + Eigen::VectorXd regression_coeffs; + double var; + + Eigen::VectorXd get_unconstrained() { + Eigen::VectorXd temp(regression_coeffs.size() + 1); + temp << regression_coeffs, var; + return uni_lin_reg_to_unconstrained(temp); + } + + void set_from_unconstrained(Eigen::VectorXd in) { + Eigen::VectorXd temp = uni_lin_reg_to_constrained(in); + int dim = in.size() - 1; + regression_coeffs = temp.head(dim); + var = temp(dim); + } + + void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { + mean = state_.uni_ls_state().mean(); + var = state_.uni_ls_state().var(); + } + + // bayesmix::AlgorithmState::ClusterState get_as_proto() { + // bayesmix::AlgorithmState::ClusterState state; + // state.mutable_uni_ls_state()->set_mean(mean); + // state.mutable_uni_ls_state()->set_var(var); + // return state; + // } + + double log_det_jac() { + Eigen::VectorXd temp(regression_coeffs.size() + 1); + temp << regression_coeffs, var; + return uni_lin_reg_log_det_jac(temp); + } +}; + +} // namespace State + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_STATE_H_ From 5aa234ac8bd4dd29ba9c19d00f95161b0db525db Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 11:19:10 +0100 Subject: [PATCH 142/317] get_state_proto() now public --- src/hierarchies/likelihoods/multi_norm_likelihood.h | 3 ++- src/hierarchies/likelihoods/states/uni_lin_reg_state.h | 9 +++++---- src/hierarchies/likelihoods/uni_norm_likelihood.h | 3 ++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index 4bd936f3f..f936a4c5a 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -32,9 +32,10 @@ class MultiNormLikelihood Eigen::VectorXd get_data_sum() const { return data_sum; }; Eigen::MatrixXd get_data_sum_squares() const { return data_sum_squares; }; - protected: std::shared_ptr get_state_proto() const override; + + protected: double compute_lpdf(const Eigen::RowVectorXd &datum) const override; void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; diff --git a/src/hierarchies/likelihoods/states/uni_lin_reg_state.h b/src/hierarchies/likelihoods/states/uni_lin_reg_state.h index b7e889afa..a97bfa1c0 100644 --- a/src/hierarchies/likelihoods/states/uni_lin_reg_state.h +++ b/src/hierarchies/likelihoods/states/uni_lin_reg_state.h @@ -55,10 +55,11 @@ class UniLinReg { var = temp(dim); } - void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { - mean = state_.uni_ls_state().mean(); - var = state_.uni_ls_state().var(); - } + // void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) + // { + // mean = state_.uni_ls_state().mean(); + // var = state_.uni_ls_state().var(); + // } // bayesmix::AlgorithmState::ClusterState get_as_proto() { // bayesmix::AlgorithmState::ClusterState state; diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 1579a6191..02de25ba5 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -36,9 +36,10 @@ class UniNormLikelihood return out; } - protected: std::shared_ptr get_state_proto() const override; + + protected: double compute_lpdf(const Eigen::RowVectorXd &datum) const override; void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; From 80fad753432f31977737ed2021f3f6162d517088 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 11:20:31 +0100 Subject: [PATCH 143/317] Fix includes --- src/algorithms/split_and_merge_algorithm.h | 1 + src/mixings/mixture_finite_mixing.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/algorithms/split_and_merge_algorithm.h b/src/algorithms/split_and_merge_algorithm.h index 43d595c53..e19cb9525 100644 --- a/src/algorithms/split_and_merge_algorithm.h +++ b/src/algorithms/split_and_merge_algorithm.h @@ -3,6 +3,7 @@ #include #include +#include #include "algorithm_id.pb.h" #include "marginal_algorithm.h" diff --git a/src/mixings/mixture_finite_mixing.h b/src/mixings/mixture_finite_mixing.h index 9813f986a..63de8f81e 100644 --- a/src/mixings/mixture_finite_mixing.h +++ b/src/mixings/mixture_finite_mixing.h @@ -5,6 +5,7 @@ #include #include +#include #include #include "base_mixing.h" From e10e3551dc3a905ec9d12304c16a44c51e923ce4 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 11:20:55 +0100 Subject: [PATCH 144/317] Add documentation --- src/hierarchies/abstract_hierarchy.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index e544780c4..8d7aa7485 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -63,6 +63,7 @@ class AbstractHierarchy { //! Returns an independent, data-less copy of this object virtual std::shared_ptr clone() const = 0; + //! Returns an independent, data-less copy of this object virtual std::shared_ptr deep_clone() const = 0; // EVALUATION FUNCTIONS FOR SINGLE POINTS From 2929a995157acc8aaf94e7209cee1f9df78c9bcf Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 11:21:52 +0100 Subject: [PATCH 145/317] deep_clone() redefined --- src/hierarchies/base_hierarchy.h | 50 +++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index fef41e258..d8c0b9a26 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -87,6 +87,16 @@ class BaseHierarchy : public AbstractHierarchy { return out; }; + std::shared_ptr deep_clone() const override { + // Create copy of the hierarchy + auto out = std::make_shared(static_cast(*this)); + // Simple Clone is enough for Likelihood + out->set_likelihood(std::static_pointer_cast(like->clone())); + // Deep-Clone required for PriorModel + out->set_prior(std::static_pointer_cast(prior->deep_clone())); + return out; + } + // NOT SURE THIS IS CORRECT, MAYBE OVERRIDE GET_LIKE_LPDF? OR THIS IS EVEN // UNNECESSARY double like_lpdf(const Eigen::RowVectorXd &datum) const override { @@ -94,23 +104,24 @@ class BaseHierarchy : public AbstractHierarchy { } //! Returns an independent, data-less copy of this object - std::shared_ptr deep_clone() const override { - auto out = std::make_shared(static_cast(*this)); + // std::shared_ptr deep_clone() const override { + // auto out = std::make_shared(static_cast(*this)); - out->clear_data(); - out->clear_summary_statistics(); + // out->clear_data(); + // out->clear_summary_statistics(); - out->create_empty_prior(); - std::shared_ptr new_prior(prior->New()); - new_prior->CopyFrom(*prior.get()); - out->get_mutable_prior()->CopyFrom(*new_prior.get()); + // out->create_empty_prior(); + // std::shared_ptr new_prior(prior->New()); + // new_prior->CopyFrom(*prior.get()); + // out->get_mutable_prior()->CopyFrom(*new_prior.get()); - out->create_empty_hypers(); - auto curr_hypers_proto = get_hypers_proto(); - out->set_hypers_from_proto(*curr_hypers_proto.get()); - out->initialize(); - return out; - } + // out->create_empty_hypers(); + // auto curr_hypers_proto = get_hypers_proto(); + // out->set_hypers_from_proto(*curr_hypers_proto.get()); + // out->initialize(); + // return out; + // } //! Evaluates the log-likelihood of data in a grid of points //! @param data Grid of points (by row) which are to be evaluated @@ -261,6 +272,14 @@ class BaseHierarchy : public AbstractHierarchy { //! Returns the indexes of data points belonging to this cluster std::set get_data_idx() const override { return like->get_data_idx(); }; + //! Writes current state to a Protobuf message and return a shared_ptr + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::ClusterState message by adding the appropriate type + std::shared_ptr get_state_proto() + const override { + return like->get_state_proto(); + } + //! Returns a pointer to the Protobuf message of the prior of this cluster google::protobuf::Message *get_mutable_prior() override { return prior->get_mutable_prior(); @@ -350,6 +369,9 @@ class BaseHierarchy : public AbstractHierarchy { throw std::runtime_error("marg_lpdf() not yet implemented"); } } + + // TEMPORANEO! + const Eigen::MatrixXd *dataset_ptr; }; // TODO: Move definitions outside the class to improve code cleaness From c0161fef9b24dfd4ad0f231ed4e4fdbbdee7547c Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 11:22:29 +0100 Subject: [PATCH 146/317] deep_clone() method added (ONGOING) --- src/hierarchies/priors/abstract_prior_model.h | 3 +++ src/hierarchies/priors/base_prior_model.h | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index aec5bc15e..74f7c03af 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -18,6 +18,9 @@ class AbstractPriorModel { // IMPLEMENTED in BasePriorModel virtual std::shared_ptr clone() const = 0; + // IMPLEMENTED in BasePriorModel + virtual std::shared_ptr deep_clone() const = 0; + virtual double lpdf(const google::protobuf::Message &state_) = 0; //! Evaluates the log likelihood for unconstrained parameter values. diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 0608bc40a..b231af091 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -70,6 +70,8 @@ class BasePriorModel : public AbstractPriorModel { virtual std::shared_ptr clone() const override; + virtual std::shared_ptr deep_clone() const override; + virtual google::protobuf::Message *get_mutable_prior() override; HyperParams get_hypers() const { return hypers; } @@ -121,6 +123,22 @@ BasePriorModel::clone() const { return out; } +template +std::shared_ptr +BasePriorModel::deep_clone() const { + auto out = std::make_shared(static_cast(*this)); + + // Prior Deep-clone + out->create_empty_prior(); + std::shared_ptr new_prior(prior->New()); + new_prior->CopyFrom(*prior.get()); + out->get_mutable_prior()->CopyFrom(*new_prior.get()); + + // C'è da farlo anche su Hypers, ma va cambiata della roba! + + return out; +} + template google::protobuf::Message * BasePriorModel::get_mutable_prior() { From a1a1c115cde5a40bdaa4036e5bb004c185b7bce3 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 11:23:06 +0100 Subject: [PATCH 147/317] Silenced tests for FAHierarchy --- test/hierarchies.cc | 216 ++++++++++++++++++++++---------------------- 1 file changed, 109 insertions(+), 107 deletions(-) diff --git a/test/hierarchies.cc b/test/hierarchies.cc index 28eca7daf..512c10847 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -300,112 +300,114 @@ TEST(nnxig_hierarchy, sample_given_data) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -TEST(fahierarchy, draw) { - auto hier = std::make_shared(); - bayesmix::FAPrior prior; - Eigen::VectorXd mutilde(4); - mutilde << 3.0, 3.0, 4.0, 1.0; - bayesmix::Vector mutilde_proto; - bayesmix::to_proto(mutilde, &mutilde_proto); - int q = 2; - double phi = 1.0; - double alpha0 = 5.0; - Eigen::VectorXd beta(4); - beta << 3.0, 3.0, 2.0, 2.1; - bayesmix::Vector beta_proto; - bayesmix::to_proto(beta, &beta_proto); - *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; - prior.mutable_fixed_values()->set_phi(phi); - prior.mutable_fixed_values()->set_alpha0(alpha0); - prior.mutable_fixed_values()->set_q(q); - *prior.mutable_fixed_values()->mutable_beta() = beta_proto; - hier->get_mutable_prior()->CopyFrom(prior); - hier->initialize(); - - auto hier2 = hier->clone(); - hier2->sample_prior(); - - bayesmix::AlgorithmState out; - bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); - bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); - hier->write_state_to_proto(clusval); - hier2->write_state_to_proto(clusval2); - - ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -} - -TEST(fahierarchy, draw_auto) { - auto hier = std::make_shared(); - bayesmix::FAPrior prior; - Eigen::VectorXd mutilde(0); - bayesmix::Vector mutilde_proto; - bayesmix::to_proto(mutilde, &mutilde_proto); - int q = 2; - double phi = 1.0; - double alpha0 = 5.0; - Eigen::VectorXd beta(0); - bayesmix::Vector beta_proto; - bayesmix::to_proto(beta, &beta_proto); - Eigen::MatrixXd dataset(5, 5); - dataset << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, - 20, 1, 5, 7, 8, 9; - hier->set_dataset(&dataset); - *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; - prior.mutable_fixed_values()->set_phi(phi); - prior.mutable_fixed_values()->set_alpha0(alpha0); - prior.mutable_fixed_values()->set_q(q); - *prior.mutable_fixed_values()->mutable_beta() = beta_proto; - hier->get_mutable_prior()->CopyFrom(prior); - hier->initialize(); - - auto hier2 = hier->clone(); - hier2->sample_prior(); - - bayesmix::AlgorithmState out; - bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); - bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); - hier->write_state_to_proto(clusval); - hier2->write_state_to_proto(clusval2); - - ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()) - << clusval->DebugString() << clusval2->DebugString(); -} - -TEST(fahierarchy, sample_given_data) { - auto hier = std::make_shared(); - bayesmix::FAPrior prior; - Eigen::VectorXd mutilde(4); - mutilde << 3.0, 3.0, 4.0, 1.0; - bayesmix::Vector mutilde_proto; - bayesmix::to_proto(mutilde, &mutilde_proto); - int q = 2; - double phi = 1.0; - double alpha0 = 5.0; - Eigen::VectorXd beta(4); - beta << 3.0, 3.0, 2.0, 2.1; - bayesmix::Vector beta_proto; - bayesmix::to_proto(beta, &beta_proto); - *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; - prior.mutable_fixed_values()->set_phi(phi); - prior.mutable_fixed_values()->set_alpha0(alpha0); - prior.mutable_fixed_values()->set_q(q); - *prior.mutable_fixed_values()->mutable_beta() = beta_proto; - hier->get_mutable_prior()->CopyFrom(prior); - Eigen::MatrixXd dataset(5, 4); - dataset << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, - 20; - hier->set_dataset(&dataset); - hier->initialize(); +// TEST(fahierarchy, draw) { +// auto hier = std::make_shared(); +// bayesmix::FAPrior prior; +// Eigen::VectorXd mutilde(4); +// mutilde << 3.0, 3.0, 4.0, 1.0; +// bayesmix::Vector mutilde_proto; +// bayesmix::to_proto(mutilde, &mutilde_proto); +// int q = 2; +// double phi = 1.0; +// double alpha0 = 5.0; +// Eigen::VectorXd beta(4); +// beta << 3.0, 3.0, 2.0, 2.1; +// bayesmix::Vector beta_proto; +// bayesmix::to_proto(beta, &beta_proto); +// *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; +// prior.mutable_fixed_values()->set_phi(phi); +// prior.mutable_fixed_values()->set_alpha0(alpha0); +// prior.mutable_fixed_values()->set_q(q); +// *prior.mutable_fixed_values()->mutable_beta() = beta_proto; +// hier->get_mutable_prior()->CopyFrom(prior); +// hier->initialize(); + +// auto hier2 = hier->clone(); +// hier2->sample_prior(); + +// bayesmix::AlgorithmState out; +// bayesmix::AlgorithmState::ClusterState* clusval = +// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 +// = out.add_cluster_states(); hier->write_state_to_proto(clusval); +// hier2->write_state_to_proto(clusval2); + +// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +// } - auto hier2 = hier->clone(); - hier2->add_datum(0, dataset.row(0), false); - hier2->add_datum(1, dataset.row(1), false); - hier2->sample_full_cond(); - bayesmix::AlgorithmState out; - bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); - bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); - hier->write_state_to_proto(clusval); - hier2->write_state_to_proto(clusval2); +// TEST(fahierarchy, draw_auto) { +// auto hier = std::make_shared(); +// bayesmix::FAPrior prior; +// Eigen::VectorXd mutilde(0); +// bayesmix::Vector mutilde_proto; +// bayesmix::to_proto(mutilde, &mutilde_proto); +// int q = 2; +// double phi = 1.0; +// double alpha0 = 5.0; +// Eigen::VectorXd beta(0); +// bayesmix::Vector beta_proto; +// bayesmix::to_proto(beta, &beta_proto); +// Eigen::MatrixXd dataset(5, 5); +// dataset << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, +// 19, +// 20, 1, 5, 7, 8, 9; +// hier->set_dataset(&dataset); +// *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; +// prior.mutable_fixed_values()->set_phi(phi); +// prior.mutable_fixed_values()->set_alpha0(alpha0); +// prior.mutable_fixed_values()->set_q(q); +// *prior.mutable_fixed_values()->mutable_beta() = beta_proto; +// hier->get_mutable_prior()->CopyFrom(prior); +// hier->initialize(); + +// auto hier2 = hier->clone(); +// hier2->sample_prior(); + +// bayesmix::AlgorithmState out; +// bayesmix::AlgorithmState::ClusterState* clusval = +// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 +// = out.add_cluster_states(); hier->write_state_to_proto(clusval); +// hier2->write_state_to_proto(clusval2); + +// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()) +// << clusval->DebugString() << clusval2->DebugString(); +// } - ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -} +// TEST(fahierarchy, sample_given_data) { +// auto hier = std::make_shared(); +// bayesmix::FAPrior prior; +// Eigen::VectorXd mutilde(4); +// mutilde << 3.0, 3.0, 4.0, 1.0; +// bayesmix::Vector mutilde_proto; +// bayesmix::to_proto(mutilde, &mutilde_proto); +// int q = 2; +// double phi = 1.0; +// double alpha0 = 5.0; +// Eigen::VectorXd beta(4); +// beta << 3.0, 3.0, 2.0, 2.1; +// bayesmix::Vector beta_proto; +// bayesmix::to_proto(beta, &beta_proto); +// *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; +// prior.mutable_fixed_values()->set_phi(phi); +// prior.mutable_fixed_values()->set_alpha0(alpha0); +// prior.mutable_fixed_values()->set_q(q); +// *prior.mutable_fixed_values()->mutable_beta() = beta_proto; +// hier->get_mutable_prior()->CopyFrom(prior); +// Eigen::MatrixXd dataset(5, 4); +// dataset << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, +// 19, +// 20; +// hier->set_dataset(&dataset); +// hier->initialize(); + +// auto hier2 = hier->clone(); +// hier2->add_datum(0, dataset.row(0), false); +// hier2->add_datum(1, dataset.row(1), false); +// hier2->sample_full_cond(); +// bayesmix::AlgorithmState out; +// bayesmix::AlgorithmState::ClusterState* clusval = +// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 +// = out.add_cluster_states(); hier->write_state_to_proto(clusval); +// hier2->write_state_to_proto(clusval2); + +// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +// } From 167bde53fe1b70e7e99024dd64c8301eca31935e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 11:26:37 +0100 Subject: [PATCH 148/317] Silenced hierarchies not yet re-factored --- src/hierarchies/{fa_hierarchy.cc => OLD_fa_hierarchy.cc} | 3 +-- src/hierarchies/{fa_hierarchy.h => OLD_fa_hierarchy.h} | 0 .../{lapnig_hierarchy.cc => OLD_lapnig_hierarchy.cc} | 3 +-- src/hierarchies/{lapnig_hierarchy.h => OLD_lapnig_hierarchy.h} | 0 src/proto/algorithm_state.proto | 2 +- 5 files changed, 3 insertions(+), 5 deletions(-) rename src/hierarchies/{fa_hierarchy.cc => OLD_fa_hierarchy.cc} (99%) rename src/hierarchies/{fa_hierarchy.h => OLD_fa_hierarchy.h} (100%) rename src/hierarchies/{lapnig_hierarchy.cc => OLD_lapnig_hierarchy.cc} (99%) rename src/hierarchies/{lapnig_hierarchy.h => OLD_lapnig_hierarchy.h} (100%) diff --git a/src/hierarchies/fa_hierarchy.cc b/src/hierarchies/OLD_fa_hierarchy.cc similarity index 99% rename from src/hierarchies/fa_hierarchy.cc rename to src/hierarchies/OLD_fa_hierarchy.cc index b8473e00e..71c0aceb4 100644 --- a/src/hierarchies/fa_hierarchy.cc +++ b/src/hierarchies/OLD_fa_hierarchy.cc @@ -1,5 +1,3 @@ -#include "fa_hierarchy.h" - #include #include @@ -9,6 +7,7 @@ #include #include "algorithm_state.pb.h" +#include "fa_hierarchy.h" #include "hierarchy_prior.pb.h" #include "ls_state.pb.h" #include "src/utils/proto_utils.h" diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/OLD_fa_hierarchy.h similarity index 100% rename from src/hierarchies/fa_hierarchy.h rename to src/hierarchies/OLD_fa_hierarchy.h diff --git a/src/hierarchies/lapnig_hierarchy.cc b/src/hierarchies/OLD_lapnig_hierarchy.cc similarity index 99% rename from src/hierarchies/lapnig_hierarchy.cc rename to src/hierarchies/OLD_lapnig_hierarchy.cc index b0d479244..1d512de1a 100644 --- a/src/hierarchies/lapnig_hierarchy.cc +++ b/src/hierarchies/OLD_lapnig_hierarchy.cc @@ -1,5 +1,3 @@ -#include "lapnig_hierarchy.h" - #include #include @@ -10,6 +8,7 @@ #include "algorithm_state.pb.h" #include "hierarchy_prior.pb.h" +#include "lapnig_hierarchy.h" #include "ls_state.pb.h" #include "src/utils/rng.h" diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/OLD_lapnig_hierarchy.h similarity index 100% rename from src/hierarchies/lapnig_hierarchy.h rename to src/hierarchies/OLD_lapnig_hierarchy.h diff --git a/src/proto/algorithm_state.proto b/src/proto/algorithm_state.proto index 3fa5169aa..0f387eda9 100644 --- a/src/proto/algorithm_state.proto +++ b/src/proto/algorithm_state.proto @@ -25,7 +25,7 @@ message AlgorithmState { MultiLSState multi_ls_state = 2; // State of a multivariate location-scale family LinRegUniLSState lin_reg_uni_ls_state = 4; // State of a linear regression univariate location-scale family Vector general_state = 5; // Just a vector of doubles - FAState fa_state = 6; // State of a Mixture of Factor Analysers + // FAState fa_state = 6; // State of a Mixture of Factor Analysers } int32 cardinality = 3; // How many observations are in this cluster } From 079d63e7e6a4b0173b49e7a33868817aad6e7513 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 12:20:09 +0100 Subject: [PATCH 149/317] Update includes --- src/algorithms/split_and_merge_algorithm.h | 1 - src/mixings/mixture_finite_mixing.h | 1 - 2 files changed, 2 deletions(-) diff --git a/src/algorithms/split_and_merge_algorithm.h b/src/algorithms/split_and_merge_algorithm.h index e19cb9525..61d5c0d7d 100644 --- a/src/algorithms/split_and_merge_algorithm.h +++ b/src/algorithms/split_and_merge_algorithm.h @@ -1,7 +1,6 @@ #ifndef BAYESMIX_ALGORITHMS_SPLIT_AND_MERGE_ALGORITHM_H_ #define BAYESMIX_ALGORITHMS_SPLIT_AND_MERGE_ALGORITHM_H_ -#include #include #include diff --git a/src/mixings/mixture_finite_mixing.h b/src/mixings/mixture_finite_mixing.h index 63de8f81e..3b85a4d10 100644 --- a/src/mixings/mixture_finite_mixing.h +++ b/src/mixings/mixture_finite_mixing.h @@ -3,7 +3,6 @@ #include -#include #include #include #include From 00cd9c546d119b3b56e2dc338cf5d109c4b8f21f Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 12:21:04 +0100 Subject: [PATCH 150/317] Improvements --- src/hierarchies/priors/base_prior_model.h | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index b231af091..9688c6248 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -74,7 +74,7 @@ class BasePriorModel : public AbstractPriorModel { virtual google::protobuf::Message *get_mutable_prior() override; - HyperParams get_hypers() const { return hypers; } + HyperParams get_hypers() const { return *hypers; } HyperParams get_posterior_hypers() const { return post_hypers; } @@ -91,6 +91,8 @@ class BasePriorModel : public AbstractPriorModel { void create_empty_prior() { prior.reset(new Prior); } + void create_empty_hypers() { hypers.reset(new HyperParams); } + bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( google::protobuf::Message *state_) const { return google::protobuf::internal::down_cast< @@ -109,7 +111,7 @@ class BasePriorModel : public AbstractPriorModel { const bayesmix::AlgorithmState::ClusterState &>(state_); } - HyperParams hypers; + std::shared_ptr hypers; HyperParams post_hypers; std::shared_ptr prior; }; @@ -134,7 +136,13 @@ BasePriorModel::deep_clone() const { new_prior->CopyFrom(*prior.get()); out->get_mutable_prior()->CopyFrom(*new_prior.get()); - // C'è da farlo anche su Hypers, ma va cambiata della roba! + // HyperParams Deep-clone + out->create_empty_hypers(); + auto curr_hypers_proto = get_hypers_proto(); + out->set_hypers_from_proto(*curr_hypers_proto.get()); + + // Initialization of Deep-cloned object + out->initialize(); return out; } From e911c1afed3dea047d68cc5aa4d9a6a7ba41a3d3 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 12:21:38 +0100 Subject: [PATCH 151/317] hypers is back to pointer --- src/hierarchies/priors/nig_prior_model.cc | 76 +++++++------- src/hierarchies/priors/nig_prior_model.h | 6 +- src/hierarchies/priors/nw_prior_model.cc | 109 +++++++++++---------- src/hierarchies/priors/nxig_prior_model.cc | 39 ++++---- 4 files changed, 115 insertions(+), 115 deletions(-) diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index a92af1b7b..67009462c 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -3,34 +3,34 @@ void NIGPriorModel::initialize_hypers() { if (prior->has_fixed_values()) { // Set values - hypers.mean = prior->fixed_values().mean(); - hypers.var_scaling = prior->fixed_values().var_scaling(); - hypers.shape = prior->fixed_values().shape(); - hypers.scale = prior->fixed_values().scale(); + hypers->mean = prior->fixed_values().mean(); + hypers->var_scaling = prior->fixed_values().var_scaling(); + hypers->shape = prior->fixed_values().shape(); + hypers->scale = prior->fixed_values().scale(); // Check validity - if (hypers.var_scaling <= 0) { + if (hypers->var_scaling <= 0) { throw std::invalid_argument("Variance-scaling parameter must be > 0"); } - if (hypers.shape <= 0) { + if (hypers->shape <= 0) { throw std::invalid_argument("Shape parameter must be > 0"); } - if (hypers.scale <= 0) { + if (hypers->scale <= 0) { throw std::invalid_argument("scale parameter must be > 0"); } } else if (prior->has_normal_mean_prior()) { // Set initial values - hypers.mean = prior->normal_mean_prior().mean_prior().mean(); - hypers.var_scaling = prior->normal_mean_prior().var_scaling(); - hypers.shape = prior->normal_mean_prior().shape(); - hypers.scale = prior->normal_mean_prior().scale(); + hypers->mean = prior->normal_mean_prior().mean_prior().mean(); + hypers->var_scaling = prior->normal_mean_prior().var_scaling(); + hypers->shape = prior->normal_mean_prior().shape(); + hypers->scale = prior->normal_mean_prior().scale(); // Check validity - if (hypers.var_scaling <= 0) { + if (hypers->var_scaling <= 0) { throw std::invalid_argument("Variance-scaling parameter must be > 0"); } - if (hypers.shape <= 0) { + if (hypers->shape <= 0) { throw std::invalid_argument("Shape parameter must be > 0"); } - if (hypers.scale <= 0) { + if (hypers->scale <= 0) { throw std::invalid_argument("scale parameter must be > 0"); } } else if (prior->has_ngg_prior()) { @@ -66,10 +66,10 @@ void NIGPriorModel::initialize_hypers() { throw std::invalid_argument("Shape parameter must be > 0"); } // Set initial values - hypers.mean = mu00; - hypers.var_scaling = alpha00 / beta00; - hypers.shape = alpha0; - hypers.scale = a00 / b00; + hypers->mean = mu00; + hypers->var_scaling = alpha00 / beta00; + hypers->shape = alpha0; + hypers->scale = a00 / b00; } else { throw std::invalid_argument("Unrecognized hierarchy prior"); } @@ -78,16 +78,16 @@ void NIGPriorModel::initialize_hypers() { double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { auto &state = downcast_state(state_).uni_ls_state(); double target = - stan::math::normal_lpdf(state.mean(), hypers.mean, - sqrt(state.var() / hypers.var_scaling)) + - stan::math::inv_gamma_lpdf(state.var(), hypers.shape, hypers.scale); + stan::math::normal_lpdf(state.mean(), hypers->mean, + sqrt(state.var() / hypers->var_scaling)) + + stan::math::inv_gamma_lpdf(state.var(), hypers->shape, hypers->scale); return target; } std::shared_ptr NIGPriorModel::sample( bool use_post_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - Hyperparams::NIG params = use_post_hypers ? post_hypers : hypers; + Hyperparams::NIG params = use_post_hypers ? post_hypers : *hypers; double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); double mean = stan::math::normal_rng(params.mean, sqrt(var / params.var_scaling), rng); @@ -123,7 +123,7 @@ void NIGPriorModel::update_hypers( double mu_n = num / prec; double sig2_n = 1 / prec; // Update hyperparameters with posterior random sampling - hypers.mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); + hypers->mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); } else if (prior->has_ngg_prior()) { // Get hyperparameters: // for mu0 @@ -144,20 +144,20 @@ void NIGPriorModel::update_hypers( double var = st.uni_ls_state().var(); b_n += 1 / var; num += mean / var; - beta_n += (hypers.mean - mean) * (hypers.mean - mean) / var; + beta_n += (hypers->mean - mean) * (hypers->mean - mean) / var; } - double var = hypers.var_scaling * b_n + 1 / sig200; + double var = hypers->var_scaling * b_n + 1 / sig200; b_n += b00; - num = hypers.var_scaling * num + mu00 / sig200; + num = hypers->var_scaling * num + mu00 / sig200; beta_n = beta00 + 0.5 * beta_n; double sig_n = 1 / var; double mu_n = num / var; double alpha_n = alpha00 + 0.5 * states.size(); - double a_n = a00 + states.size() * hypers.shape; + double a_n = a00 + states.size() * hypers->shape; // Update hyperparameters with posterior random Gibbs sampling - hypers.mean = stan::math::normal_rng(mu_n, sig_n, rng); - hypers.var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); - hypers.scale = stan::math::gamma_rng(a_n, b_n, rng); + hypers->mean = stan::math::normal_rng(mu_n, sig_n, rng); + hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); + hypers->scale = stan::math::gamma_rng(a_n, b_n, rng); } else { throw std::invalid_argument("Unrecognized hierarchy prior"); } @@ -166,19 +166,19 @@ void NIGPriorModel::update_hypers( void NIGPriorModel::set_hypers_from_proto( const google::protobuf::Message &hypers_) { auto &hyperscast = downcast_hypers(hypers_).nnig_state(); - hypers.mean = hyperscast.mean(); - hypers.var_scaling = hyperscast.var_scaling(); - hypers.scale = hyperscast.scale(); - hypers.shape = hyperscast.shape(); + hypers->mean = hyperscast.mean(); + hypers->var_scaling = hyperscast.var_scaling(); + hypers->scale = hyperscast.scale(); + hypers->shape = hyperscast.shape(); } std::shared_ptr NIGPriorModel::get_hypers_proto() const { bayesmix::NIGDistribution hypers_; - hypers_.set_mean(hypers.mean); - hypers_.set_var_scaling(hypers.var_scaling); - hypers_.set_shape(hypers.shape); - hypers_.set_scale(hypers.scale); + hypers_.set_mean(hypers->mean); + hypers_.set_var_scaling(hypers->var_scaling); + hypers_.set_shape(hypers->shape); + hypers_.set_scale(hypers->scale); auto out = std::make_shared(); out->mutable_nnig_state()->CopyFrom(hypers_); diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index daa143477..0cda78004 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -30,9 +30,9 @@ class NIGPriorModel : public BasePriorModelmean, + sqrt(var / hypers->var_scaling)) + + stan::math::inv_gamma_lpdf(var, hypers->shape, hypers->scale); return lpdf + log_det_jac; } diff --git a/src/hierarchies/priors/nw_prior_model.cc b/src/hierarchies/priors/nw_prior_model.cc index 6bdecb724..ab7090dc8 100644 --- a/src/hierarchies/priors/nw_prior_model.cc +++ b/src/hierarchies/priors/nw_prior_model.cc @@ -1,26 +1,26 @@ #include "nw_prior_model.h" +#include "src/utils/distributions.h" #include "src/utils/eigen_utils.h" #include "src/utils/proto_utils.h" -#include "src/utils/distributions.h" void NWPriorModel::initialize_hypers() { if (prior->has_fixed_values()) { // Set values - hypers.mean = bayesmix::to_eigen(prior->fixed_values().mean()); - dim = hypers.mean.size(); - hypers.var_scaling = prior->fixed_values().var_scaling(); - hypers.scale = bayesmix::to_eigen(prior->fixed_values().scale()); - hypers.deg_free = prior->fixed_values().deg_free(); + hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); + dim = hypers->mean.size(); + hypers->var_scaling = prior->fixed_values().var_scaling(); + hypers->scale = bayesmix::to_eigen(prior->fixed_values().scale()); + hypers->deg_free = prior->fixed_values().deg_free(); // Check validity - if (hypers.var_scaling <= 0) { + if (hypers->var_scaling <= 0) { throw std::invalid_argument("Variance-scaling parameter must be > 0"); } - if (dim != hypers.scale.rows()) { + if (dim != hypers->scale.rows()) { throw std::invalid_argument( "Hyperparameters dimensions are not consistent"); } - if (hypers.deg_free <= dim - 1) { + if (hypers->deg_free <= dim - 1) { throw std::invalid_argument("Degrees of freedom parameter is not valid"); } } @@ -51,10 +51,10 @@ void NWPriorModel::initialize_hypers() { throw std::invalid_argument("Degrees of freedom parameter is not valid"); } // Set initial values - hypers.mean = mu00; - hypers.var_scaling = lambda0; - hypers.scale = tau0; - hypers.deg_free = nu0; + hypers->mean = mu00; + hypers->var_scaling = lambda0; + hypers->scale = tau0; + hypers->deg_free = nu0; } else if (prior->has_ngiw_prior()) { @@ -99,36 +99,37 @@ void NWPriorModel::initialize_hypers() { throw std::invalid_argument("Degrees of freedom parameter is not valid"); } // Set initial values - hypers.mean = mu00; - hypers.var_scaling = alpha00 / beta00; - hypers.scale = tau00 / (nu00 + dim + 1); - hypers.deg_free = nu0; + hypers->mean = mu00; + hypers->var_scaling = alpha00 / beta00; + hypers->scale = tau00 / (nu00 + dim + 1); + hypers->deg_free = nu0; } else { throw std::invalid_argument("Unrecognized hierarchy prior"); } - hypers.scale_inv = stan::math::inverse_spd(hypers.scale); - hypers.scale_chol = Eigen::LLT(hypers.scale).matrixU(); + hypers->scale_inv = stan::math::inverse_spd(hypers->scale); + hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); } double NWPriorModel::lpdf(const google::protobuf::Message &state_) { auto &state = downcast_state(state_).multi_ls_state(); Eigen::VectorXd mean = bayesmix::to_eigen(state.mean()); Eigen::MatrixXd prec = bayesmix::to_eigen(state.prec()); - double target = stan::math::multi_normal_prec_lpdf(mean, hypers.mean, prec * hypers.var_scaling) + - stan::math::wishart_lpdf(prec, hypers.deg_free, hypers.scale); + double target = + stan::math::multi_normal_prec_lpdf(mean, hypers->mean, + prec * hypers->var_scaling) + + stan::math::wishart_lpdf(prec, hypers->deg_free, hypers->scale); return target; } std::shared_ptr NWPriorModel::sample( bool use_post_hypers) { - auto &rng = bayesmix::Rng::Instance().get(); - - Hyperparams::NW params = use_post_hypers ? post_hypers : hypers; + + Hyperparams::NW params = use_post_hypers ? post_hypers : *hypers; Eigen::MatrixXd tau_new = stan::math::wishart_rng(params.deg_free, params.scale, rng); - + // Update state State::MultiLS out; out.mean = stan::math::multi_normal_prec_rng( @@ -146,15 +147,15 @@ std::shared_ptr NWPriorModel::sample( bayesmix::AlgorithmState::ClusterState state; state.mutable_multi_ls_state()->mutable_mean()->CopyFrom(mean_proto); state.mutable_multi_ls_state()->mutable_prec()->CopyFrom(prec_proto); - state.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom(prec_chol_proto); + state.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom( + prec_chol_proto); return std::make_shared(state); }; void NWPriorModel::update_hypers( const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - + if (prior->has_fixed_values()) { return; } @@ -175,13 +176,13 @@ void NWPriorModel::update_hypers( prec += prec_i; num += prec_i * bayesmix::to_eigen(st.multi_ls_state().mean()); } - prec = hypers.var_scaling * prec + sigma00inv; - num = hypers.var_scaling * num + sigma00inv * mu00; + prec = hypers->var_scaling * prec + sigma00inv; + num = hypers->var_scaling * num + sigma00inv * mu00; Eigen::VectorXd mu_n = prec.llt().solve(num); // Update hyperparameters with posterior sampling - hypers.mean = stan::math::multi_normal_prec_rng(mu_n, prec, rng); + hypers->mean = stan::math::multi_normal_prec_rng(mu_n, prec, rng); } - + else if (prior->has_ngiw_prior()) { // Get hyperparameters: // for mu0 @@ -207,22 +208,22 @@ void NWPriorModel::update_hypers( tau_n += prec; num += prec * mean; beta_n += - (hypers.mean - mean).transpose() * prec * (hypers.mean - mean); + (hypers->mean - mean).transpose() * prec * (hypers->mean - mean); } - Eigen::MatrixXd prec_n = hypers.var_scaling * tau_n + sigma00inv; + Eigen::MatrixXd prec_n = hypers->var_scaling * tau_n + sigma00inv; tau_n += tau00; - num = hypers.var_scaling * num + sigma00inv * mu00; + num = hypers->var_scaling * num + sigma00inv * mu00; beta_n = beta00 + 0.5 * beta_n; Eigen::MatrixXd sig_n = stan::math::inverse_spd(prec_n); Eigen::VectorXd mu_n = sig_n * num; double alpha_n = alpha00 + 0.5 * states.size(); - double nu_n = nu00 + states.size() * hypers.deg_free; + double nu_n = nu00 + states.size() * hypers->deg_free; // Update hyperparameters with posterior random Gibbs sampling - hypers.mean = stan::math::multi_normal_rng(mu_n, sig_n, rng); - hypers.var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); - hypers.scale = stan::math::inv_wishart_rng(nu_n, tau_n, rng); - hypers.scale_inv = stan::math::inverse_spd(hypers.scale); - hypers.scale_chol = Eigen::LLT(hypers.scale).matrixU(); + hypers->mean = stan::math::multi_normal_rng(mu_n, sig_n, rng); + hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); + hypers->scale = stan::math::inv_wishart_rng(nu_n, tau_n, rng); + hypers->scale_inv = stan::math::inverse_spd(hypers->scale); + hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); } else { @@ -233,36 +234,36 @@ void NWPriorModel::update_hypers( void NWPriorModel::set_hypers_from_proto( const google::protobuf::Message &hypers_) { auto &hyperscast = downcast_hypers(hypers_).nnw_state(); - hypers.mean = bayesmix::to_eigen(hyperscast.mean()); - hypers.var_scaling = hyperscast.var_scaling(); - hypers.deg_free = hyperscast.deg_free(); - hypers.scale = bayesmix::to_eigen(hyperscast.scale()); - hypers.scale_inv = stan::math::inverse_spd(hypers.scale); - hypers.scale_chol = Eigen::LLT(hypers.scale).matrixU(); + hypers->mean = bayesmix::to_eigen(hyperscast.mean()); + hypers->var_scaling = hyperscast.var_scaling(); + hypers->deg_free = hyperscast.deg_free(); + hypers->scale = bayesmix::to_eigen(hyperscast.scale()); + hypers->scale_inv = stan::math::inverse_spd(hypers->scale); + hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); } std::shared_ptr NWPriorModel::get_hypers_proto() const { - // Translate to proto bayesmix::Vector mean_proto; bayesmix::Matrix scale_proto; - bayesmix::to_proto(hypers.mean, &mean_proto); - bayesmix::to_proto(hypers.scale, &scale_proto); + bayesmix::to_proto(hypers->mean, &mean_proto); + bayesmix::to_proto(hypers->scale, &scale_proto); // Make output state and return auto out = std::make_shared(); out->mutable_nnw_state()->mutable_mean()->CopyFrom(mean_proto); - out->mutable_nnw_state()->set_var_scaling(hypers.var_scaling); - out->mutable_nnw_state()->set_deg_free(hypers.deg_free); + out->mutable_nnw_state()->set_var_scaling(hypers->var_scaling); + out->mutable_nnw_state()->set_deg_free(hypers->deg_free); out->mutable_nnw_state()->mutable_scale()->CopyFrom(scale_proto); return out; } -void NWPriorModel::write_prec_to_state(const Eigen::MatrixXd &prec_, State::MultiLS *out) { +void NWPriorModel::write_prec_to_state(const Eigen::MatrixXd &prec_, + State::MultiLS *out) { out->prec = prec_; // Update prec utilities out->prec_chol = Eigen::LLT(prec_).matrixU(); Eigen::VectorXd diag = out->prec_chol.diagonal(); out->prec_logdet = 2 * log(diag.array()).sum(); -} \ No newline at end of file +} diff --git a/src/hierarchies/priors/nxig_prior_model.cc b/src/hierarchies/priors/nxig_prior_model.cc index 2729fe67e..eec6f3871 100644 --- a/src/hierarchies/priors/nxig_prior_model.cc +++ b/src/hierarchies/priors/nxig_prior_model.cc @@ -3,18 +3,18 @@ void NxIGPriorModel::initialize_hypers() { if (prior->has_fixed_values()) { // Set values - hypers.mean = prior->fixed_values().mean(); - hypers.var = prior->fixed_values().var(); - hypers.shape = prior->fixed_values().shape(); - hypers.scale = prior->fixed_values().scale(); + hypers->mean = prior->fixed_values().mean(); + hypers->var = prior->fixed_values().var(); + hypers->shape = prior->fixed_values().shape(); + hypers->scale = prior->fixed_values().scale(); // Check validity - if (hypers.var <= 0) { + if (hypers->var <= 0) { throw std::invalid_argument("Variance parameter must be > 0"); } - if (hypers.shape <= 0) { + if (hypers->shape <= 0) { throw std::invalid_argument("Shape parameter must be > 0"); } - if (hypers.scale <= 0) { + if (hypers->scale <= 0) { throw std::invalid_argument("scale parameter must be > 0"); } } else { @@ -25,19 +25,18 @@ void NxIGPriorModel::initialize_hypers() { double NxIGPriorModel::lpdf(const google::protobuf::Message &state_) { auto &state = downcast_state(state_).uni_ls_state(); double target = - stan::math::normal_lpdf(state.mean(), hypers.mean,sqrt(hypers.var)) + - stan::math::inv_gamma_lpdf(state.var(), hypers.shape, hypers.scale); + stan::math::normal_lpdf(state.mean(), hypers->mean, sqrt(hypers->var)) + + stan::math::inv_gamma_lpdf(state.var(), hypers->shape, hypers->scale); return target; } std::shared_ptr NxIGPriorModel::sample( bool use_post_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - Hyperparams::NxIG params = use_post_hypers ? post_hypers : hypers; + Hyperparams::NxIG params = use_post_hypers ? post_hypers : *hypers; double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - double mean = - stan::math::normal_rng(params.mean, sqrt(params.var), rng); + double mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng); bayesmix::AlgorithmState::ClusterState state; state.mutable_uni_ls_state()->set_mean(mean); @@ -57,19 +56,19 @@ void NxIGPriorModel::update_hypers( void NxIGPriorModel::set_hypers_from_proto( const google::protobuf::Message &hypers_) { auto &hyperscast = downcast_hypers(hypers_).nnxig_state(); - hypers.mean = hyperscast.mean(); - hypers.var = hyperscast.var(); - hypers.scale = hyperscast.scale(); - hypers.shape = hyperscast.shape(); + hypers->mean = hyperscast.mean(); + hypers->var = hyperscast.var(); + hypers->scale = hyperscast.scale(); + hypers->shape = hyperscast.shape(); } std::shared_ptr NxIGPriorModel::get_hypers_proto() const { bayesmix::NxIGDistribution hypers_; - hypers_.set_mean(hypers.mean); - hypers_.set_var(hypers.var); - hypers_.set_shape(hypers.shape); - hypers_.set_scale(hypers.scale); + hypers_.set_mean(hypers->mean); + hypers_.set_var(hypers->var); + hypers_.set_shape(hypers->shape); + hypers_.set_scale(hypers->scale); auto out = std::make_shared(); out->mutable_nnxig_state()->CopyFrom(hypers_); From 6dda1c1ae9048b73e2a6145f3071000868137301 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 16:23:43 +0100 Subject: [PATCH 152/317] Cleaned notebook --- python/notebooks/gaussian_mix_uni.ipynb | 35 ++----------------------- 1 file changed, 2 insertions(+), 33 deletions(-) diff --git a/python/notebooks/gaussian_mix_uni.ipynb b/python/notebooks/gaussian_mix_uni.ipynb index 58aea7ca5..6e74c6142 100644 --- a/python/notebooks/gaussian_mix_uni.ipynb +++ b/python/notebooks/gaussian_mix_uni.ipynb @@ -236,30 +236,6 @@ "algo_names = [\"Neal2\", \"Neal3\", \"Neal8\"]" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dens_grid = np.linspace(-10, 10, 1000)\n", - "eval_dens, a, b, c = run_mcmc(\n", - " \"NNIG\", \"DP\", data, g0_params[0], dp_params[0], neal8_algo, dens_grid,\n", - " return_clusters=True, return_num_clusters=True,\n", - " return_best_clus=True)\n", - "\n", - "plt.plot(dens_grid, np.exp(np.mean(eval_dens, axis=0)))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ciao" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -473,13 +449,6 @@ "source": [ "np.var(data)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -487,7 +456,7 @@ "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -501,7 +470,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.10" + "version": "3.8.10" } }, "nbformat": 4, From 41690946d49b670fb51d4be1d3eeff4d5371043c Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 16:24:08 +0100 Subject: [PATCH 153/317] Commented debug string --- executables/run_mcmc.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/executables/run_mcmc.cc b/executables/run_mcmc.cc index 776bd3f56..9d1e4c8e0 100644 --- a/executables/run_mcmc.cc +++ b/executables/run_mcmc.cc @@ -168,8 +168,8 @@ int main(int argc, char *argv[]) { bayesmix::read_proto_from_file(args.get("--hier-args"), hier->get_mutable_prior()); - std::cout << "hier->prior: \n" - << hier->get_mutable_prior()->DebugString() << std::endl; + // std::cout << "hier->prior: \n" + // << hier->get_mutable_prior()->DebugString() << std::endl; hier->initialize(); From 99998698ce68753789b741dd2deb42dcf7708528 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 16:24:31 +0100 Subject: [PATCH 154/317] Fixed bug in initialization --- src/hierarchies/priors/base_prior_model.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 9688c6248..504567bde 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -168,6 +168,7 @@ void BasePriorModel::write_hypers_to_proto( template void BasePriorModel::initialize() { check_prior_is_set(); + create_empty_hypers(); initialize_hypers(); } From abe3b0905118ae5b80b9fbbf2eb9f52da842c8e9 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 21:24:19 +0100 Subject: [PATCH 155/317] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1c62222a9..c80e9f05c 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,8 @@ To build the executable for the main file `run_mcmc.cc`, please use the followin mkdir build cd build -cmake .. -DDISABLE_DOCS=on -DDISABLE_BENCHMARKS=on -DDISABLE_TESTS=on -make run +cmake .. -DDISABLE_DOCS=ON -DDISABLE_BENCHMARKS=ON -DDISABLE_TESTS=ON +make run_mcmc cd .. ``` From eb3334ab62e2019de6f4a5a28d58eb0f520ce205 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 21:25:04 +0100 Subject: [PATCH 156/317] Update README.md --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index c80e9f05c..4c7cf0959 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,6 @@ To build the executable for the main file `run_mcmc.cc`, please use the followin ```shell mkdir build cd build - cmake .. -DDISABLE_DOCS=ON -DDISABLE_BENCHMARKS=ON -DDISABLE_TESTS=ON make run_mcmc cd .. From 2cd5cdb925bf1b2adb3b47eb66adb1939f00d26c Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Feb 2022 21:32:05 +0100 Subject: [PATCH 157/317] Update README.md --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index d924dacd5..94ea57d87 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,13 @@ Current state of the software: - `bayesmix` performs inference for mixture models of the kind + + + where P is either the Dirichlet process or the Pitman--Yor process From e1aac142d9a91c33eb578c3545c64e7349e7d884 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 1 Mar 2022 12:21:17 +0100 Subject: [PATCH 158/317] Fix includes --- src/hierarchies/likelihoods/states/multi_ls_state.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h index ae35d7eb8..406d008fe 100644 --- a/src/hierarchies/likelihoods/states/multi_ls_state.h +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -1,7 +1,7 @@ #ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_MULTI_LS_STATE_H_ #define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_MULTI_LS_STATE_H_ -#include +// #include #include #include From 6f3a309b01f414c7ff9bb9621421369bae305e0e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 1 Mar 2022 12:21:56 +0100 Subject: [PATCH 159/317] No nullptr at instanciation --- src/hierarchies/priors/base_prior_model.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 504567bde..518d687e9 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -111,7 +111,7 @@ class BasePriorModel : public AbstractPriorModel { const bayesmix::AlgorithmState::ClusterState &>(state_); } - std::shared_ptr hypers; + std::shared_ptr hypers = std::make_shared(); HyperParams post_hypers; std::shared_ptr prior; }; From 67e2478ac223c15e0d5caf0def71704a082158c3 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 1 Mar 2022 12:22:42 +0100 Subject: [PATCH 160/317] Add laplace likelihood (ONGOING) --- src/hierarchies/likelihoods/CMakeLists.txt | 2 + .../likelihoods/laplace_likelihood.cc | 53 ++++++++++++++++++ .../likelihoods/laplace_likelihood.h | 55 +++++++++++++++++++ 3 files changed, 110 insertions(+) create mode 100644 src/hierarchies/likelihoods/laplace_likelihood.cc create mode 100644 src/hierarchies/likelihoods/laplace_likelihood.h diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index 4a642dd9d..b9a6a25fa 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -5,6 +5,8 @@ target_sources(bayesmix PUBLIC uni_norm_likelihood.cc multi_norm_likelihood.h multi_norm_likelihood.cc + laplace_likelihood.h + laplace_likelihood.cc ) add_subdirectory(states) diff --git a/src/hierarchies/likelihoods/laplace_likelihood.cc b/src/hierarchies/likelihoods/laplace_likelihood.cc new file mode 100644 index 000000000..7be031d13 --- /dev/null +++ b/src/hierarchies/likelihoods/laplace_likelihood.cc @@ -0,0 +1,53 @@ +#include "laplace_likelihood.h" + +double LaplaceLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { + return stan::math::double_exponential_lpdf(datum(0), state.mean, + stan::math::sqrt(state.var / 2)); +} + +void LaplaceLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, + bool add) { + if (add) { + // sum_abs_diff_curr += std::abs(state.mean - datum(0, 0)); + cluster_data_values.push_back(datum); + } else { + // sum_abs_diff_curr -= std::abs(state.mean - datum(0, 0)); + auto it = std::find(cluster_data_values.begin(), cluster_data_values.end(), + datum); + cluster_data_values.erase(it); + } +} + +void LaplaceLikelihood::set_state_from_proto( + const google::protobuf::Message &state_, bool update_card) { + auto &statecast = downcast_state(state_); + state.mean = statecast.uni_ls_state().mean(); + state.var = statecast.uni_ls_state().var(); + if (update_card) set_card(statecast.cardinality()); +} + +std::shared_ptr +LaplaceLikelihood::get_state_proto() const { + auto out = std::make_shared(); + out->mutable_uni_ls_state()->set_mean(state.mean); + out->mutable_uni_ls_state()->set_var(state.var); + return out; +} + +void LaplaceLikelihood::clear_summary_statistics() { + cluster_data_values.clear(); + sum_abs_diff_curr = 0; + sum_abs_diff_prop = 0; +} + +// double UniNormLikelihood::cluster_lpdf_from_unconstrained( +// Eigen::VectorXd unconstrained_params) { +// assert(unconstrained_params.size() == 2); +// double mean = unconstrained_params(0); +// double var = std::exp(unconstrained_params(1)); +// double out = -(data_sum_squares - 2 * mean * data_sum + card * mean * +// mean) / +// (2 * var); +// out -= card * 0.5 * std::log(stan::math::TWO_PI * var); +// return out; +// } diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h new file mode 100644 index 000000000..84a63518c --- /dev/null +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -0,0 +1,55 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_LAPLACE_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_LAPLACE_LIKELIHOOD_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states/includes.h" + +class LaplaceLikelihood + : public BaseLikelihood { + public: + LaplaceLikelihood() = default; + ~LaplaceLikelihood() = default; + bool is_multivariate() const override { return false; }; + bool is_dependent() const override { return false; }; + void set_state_from_proto(const google::protobuf::Message &state_, + bool update_card = true) override; + void clear_summary_statistics() override; + + template + T cluster_lpdf_from_unconstrained( + const Eigen::Matrix &unconstrained_params) const { + assert(unconstrained_params.size() == 2); + T mean = unconstrained_params(0); + T var = stan::math::positive_constrain(unconstrained_params(1)); + T out = 0.; + for (auto it = cluster_data_values.begin(); + it != cluster_data_values.end(); ++it) { + out += stan::math::double_exponential_lpdf(*it, mean, + stan::math::sqrt(var / 2)); + } + return out; + } + + std::shared_ptr get_state_proto() + const override; + + protected: + double compute_lpdf(const Eigen::RowVectorXd &datum) const override; + void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; + + //! Set of values of data points belonging to this cluster + std::list cluster_data_values; + //! Sum of absolute differences for current params + // double sum_abs_diff_curr = 0; + //! Sum of absolute differences for proposal params + // double sum_abs_diff_prop = 0; +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_LAPLACE_LIKELIHOOD_H_ From 14fee74482711713e07483448efc2fce5efefdfa Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 1 Mar 2022 17:34:16 +0100 Subject: [PATCH 161/317] LapNIG hierarchy re-factored --- .../lapnig_hierarchy.cc} | 3 +- .../lapnig_hierarchy.h} | 0 src/hierarchies/lapnig_hierarchy.h | 44 +++++++++++++++++++ src/hierarchies/load_hierarchies.h | 10 ++--- 4 files changed, 51 insertions(+), 6 deletions(-) rename src/hierarchies/{OLD_lapnig_hierarchy.cc => .old/lapnig_hierarchy.cc} (99%) rename src/hierarchies/{OLD_lapnig_hierarchy.h => .old/lapnig_hierarchy.h} (100%) create mode 100644 src/hierarchies/lapnig_hierarchy.h diff --git a/src/hierarchies/OLD_lapnig_hierarchy.cc b/src/hierarchies/.old/lapnig_hierarchy.cc similarity index 99% rename from src/hierarchies/OLD_lapnig_hierarchy.cc rename to src/hierarchies/.old/lapnig_hierarchy.cc index 1d512de1a..b0d479244 100644 --- a/src/hierarchies/OLD_lapnig_hierarchy.cc +++ b/src/hierarchies/.old/lapnig_hierarchy.cc @@ -1,3 +1,5 @@ +#include "lapnig_hierarchy.h" + #include #include @@ -8,7 +10,6 @@ #include "algorithm_state.pb.h" #include "hierarchy_prior.pb.h" -#include "lapnig_hierarchy.h" #include "ls_state.pb.h" #include "src/utils/rng.h" diff --git a/src/hierarchies/OLD_lapnig_hierarchy.h b/src/hierarchies/.old/lapnig_hierarchy.h similarity index 100% rename from src/hierarchies/OLD_lapnig_hierarchy.h rename to src/hierarchies/.old/lapnig_hierarchy.h diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/lapnig_hierarchy.h new file mode 100644 index 000000000..a7f049115 --- /dev/null +++ b/src/hierarchies/lapnig_hierarchy.h @@ -0,0 +1,44 @@ +#ifndef BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ + +// #include + +// #include +// #include +// #include + +// #include "algorithm_state.pb.h" +// #include "conjugate_hierarchy.h" +#include "hierarchy_id.pb.h" +// #include "hierarchy_prior.pb.h" + +#include "base_hierarchy.h" +#include "likelihoods/laplace_likelihood.h" +#include "priors/nxig_prior_model.h" +#include "updaters/mala_updater.h" + +class LapNIGHierarchy + : public BaseHierarchy { + public: + LapNIGHierarchy() = default; + ~LapNIGHierarchy() = default; + + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::LapNIG; + } + + void set_default_updater() { updater = std::make_shared(); } + + void initialize_state() override { + // Get hypers + auto hypers = prior->get_hypers(); + // Initialize likelihood state + State::UniLS state; + state.mean = hypers.mean; + state.var = hypers.scale / (hypers.shape + 1); + like->set_state(state); + }; +}; + +#endif // BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ diff --git a/src/hierarchies/load_hierarchies.h b/src/hierarchies/load_hierarchies.h index a7774dd7a..602487b80 100644 --- a/src/hierarchies/load_hierarchies.h +++ b/src/hierarchies/load_hierarchies.h @@ -7,7 +7,7 @@ #include "abstract_hierarchy.h" // #include "fa_hierarchy.h" #include "hierarchy_id.pb.h" -// #include "lapnig_hierarchy.h" +#include "lapnig_hierarchy.h" // #include "lin_reg_uni_hierarchy.h" #include "nnig_hierarchy.h" #include "nnw_hierarchy.h" @@ -40,16 +40,16 @@ __attribute__((constructor)) static void load_hierarchies() { // Builder FAbuilder = []() { // return std::make_shared(); // }; - // Builder LapNIGbuilder = []() { - // return std::make_shared(); - // }; + Builder LapNIGbuilder = []() { + return std::make_shared(); + }; factory.add_builder(NNIGHierarchy().get_id(), NNIGbuilder); factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); factory.add_builder(NNWHierarchy().get_id(), NNWbuilder); // factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); // factory.add_builder(FAHierarchy().get_id(), FAbuilder); - // factory.add_builder(LapNIGHierarchy().get_id(), LapNIGbuilder); + factory.add_builder(LapNIGHierarchy().get_id(), LapNIGbuilder); } #endif // BAYESMIX_HIERARCHIES_LOAD_HIERARCHIES_H_ From 6cce465bf926cc70c5278b063f8705bd4930724b Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 1 Mar 2022 17:34:50 +0100 Subject: [PATCH 162/317] Add lapnig_hierarchy.h --- src/hierarchies/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index fbafc64c0..3ff8486f7 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -11,7 +11,7 @@ target_sources(bayesmix # lin_reg_uni_hierarchy.cc # fa_hierarchy.h # fa_hierarchy.cc - # lapnig_hierarchy.h + lapnig_hierarchy.h # lapnig_hierarchy.cc ) From b49f82873eb7a7d20127ffa697c8dcd87c586c27 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 1 Mar 2022 17:35:24 +0100 Subject: [PATCH 163/317] Minor code fix --- src/hierarchies/likelihoods/laplace_likelihood.cc | 8 ++++---- src/hierarchies/likelihoods/laplace_likelihood.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/likelihoods/laplace_likelihood.cc b/src/hierarchies/likelihoods/laplace_likelihood.cc index 7be031d13..61bad8429 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.cc +++ b/src/hierarchies/likelihoods/laplace_likelihood.cc @@ -1,8 +1,8 @@ #include "laplace_likelihood.h" double LaplaceLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { - return stan::math::double_exponential_lpdf(datum(0), state.mean, - stan::math::sqrt(state.var / 2)); + return stan::math::double_exponential_lpdf( + datum(0), state.mean, stan::math::sqrt(state.var / 2.0)); } void LaplaceLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, @@ -36,8 +36,8 @@ LaplaceLikelihood::get_state_proto() const { void LaplaceLikelihood::clear_summary_statistics() { cluster_data_values.clear(); - sum_abs_diff_curr = 0; - sum_abs_diff_prop = 0; + // sum_abs_diff_curr = 0; + // sum_abs_diff_prop = 0; } // double UniNormLikelihood::cluster_lpdf_from_unconstrained( diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h index 84a63518c..37cd611ff 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.h +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -32,7 +32,7 @@ class LaplaceLikelihood for (auto it = cluster_data_values.begin(); it != cluster_data_values.end(); ++it) { out += stan::math::double_exponential_lpdf(*it, mean, - stan::math::sqrt(var / 2)); + stan::math::sqrt(var / 2.0)); } return out; } From 9ed4e99a026373a396e8d41ab96ae60aea864f61 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 1 Mar 2022 17:35:51 +0100 Subject: [PATCH 164/317] Add LapNIG hierarchy id --- src/proto/hierarchy_id.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/proto/hierarchy_id.proto b/src/proto/hierarchy_id.proto index e481f0871..6ab62ab93 100644 --- a/src/proto/hierarchy_id.proto +++ b/src/proto/hierarchy_id.proto @@ -11,6 +11,6 @@ enum HierarchyId { NNW = 2; // Normal - Normal Wishart LinRegUni = 3; // Linear Regression (univariate response) NNxIG = 4; // Normal - Normal x Inverse Gamma - // LapNIG = 5; // Laplace - Normal Inverse Gamma + LapNIG = 5; // Laplace - Normal Inverse Gamma // FA = 6; // Factor Analysers } From d2dd45811e761729d125efc51f0cedeed91d680e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 1 Mar 2022 17:36:32 +0100 Subject: [PATCH 165/317] Add tests for Laplace likelihood --- test/likelihoods.cc | 116 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 7 deletions(-) diff --git a/test/likelihoods.cc b/test/likelihoods.cc index bdcac6ff8..bea95a0b1 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -6,9 +6,9 @@ #include "algorithm_state.pb.h" #include "ls_state.pb.h" -#include "src/hierarchies/likelihoods/uni_norm_likelihood.h" +#include "src/hierarchies/likelihoods/laplace_likelihood.h" #include "src/hierarchies/likelihoods/multi_norm_likelihood.h" - +#include "src/hierarchies/likelihoods/uni_norm_likelihood.h" #include "src/utils/proto_utils.h" #include "src/utils/rng.h" @@ -139,7 +139,7 @@ TEST(multi_norm_likelihood, set_get_state) { bayesmix::AlgorithmState::ClusterState got_state_; // Prepare state - Eigen::Vector2d mu = {5.5, 5.5}; //mu << 5.5, 5.5; + Eigen::Vector2d mu = {5.5, 5.5}; // mu << 5.5, 5.5; Eigen::Matrix2d prec = Eigen::Matrix2d::Identity(); bayesmix::Vector mean_proto; bayesmix::Matrix prec_proto; @@ -147,7 +147,8 @@ TEST(multi_norm_likelihood, set_get_state) { bayesmix::to_proto(prec, &prec_proto); set_state_.mutable_multi_ls_state()->mutable_mean()->CopyFrom(mean_proto); set_state_.mutable_multi_ls_state()->mutable_prec()->CopyFrom(prec_proto); - set_state_.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom(prec_proto); + set_state_.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom( + prec_proto); // Set and get the state like->set_state_from_proto(set_state_); @@ -182,7 +183,7 @@ TEST(multi_norm_likelihood, eval_lpdf) { // Set state from proto bayesmix::AlgorithmState::ClusterState clust_state_; - Eigen::Vector2d mu = {5.5, 5.5}; //mu << 5.5, 5.5; + Eigen::Vector2d mu = {5.5, 5.5}; // mu << 5.5, 5.5; Eigen::Matrix2d prec = Eigen::Matrix2d::Identity(); bayesmix::Vector mean_proto; bayesmix::Matrix prec_proto; @@ -190,11 +191,12 @@ TEST(multi_norm_likelihood, eval_lpdf) { bayesmix::to_proto(prec, &prec_proto); clust_state_.mutable_multi_ls_state()->mutable_mean()->CopyFrom(mean_proto); clust_state_.mutable_multi_ls_state()->mutable_prec()->CopyFrom(prec_proto); - clust_state_.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom(prec_proto); + clust_state_.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom( + prec_proto); like->set_state_from_proto(clust_state_); // Data matrix on which evaluate the likelihood - Eigen::MatrixXd data(3,2); + Eigen::MatrixXd data(3, 2); data.row(0) << 4.5, 4.5; data.row(1) << 5.1, 5.1; data.row(2) << 2.5, 2.5; @@ -207,3 +209,103 @@ TEST(multi_norm_likelihood, eval_lpdf) { // Check if they coincides ASSERT_EQ(evals, evals_copy); } + +TEST(laplace_likelihood, set_get_state) { + // Instance + auto like = std::make_shared(); + + // Prepare buffers + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState set_state_; + bayesmix::AlgorithmState::ClusterState got_state_; + + // Prepare state + state_.set_mean(5.23); + state_.set_var(1.02); + set_state_.mutable_uni_ls_state()->CopyFrom(state_); + + // Set and get the state + like->set_state_from_proto(set_state_); + like->write_state_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(laplace_likelihood, add_remove_data) { + // Instance + auto like = std::make_shared(); + + // Add new datum to likelihood + Eigen::VectorXd datum(1); + datum << 5.0; + like->add_datum(0, datum); + + // Check if cardinality is augmented + ASSERT_EQ(like->get_card(), 1); + + // Remove datum from likelihood + like->remove_datum(0, datum); + + // Check if cardinality is reduced + ASSERT_EQ(like->get_card(), 0); +} + +TEST(laplace_likelihood, eval_lpdf) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + state_.set_mean(5); + state_.set_var(1); + clust_state_.mutable_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Add new datum to likelihood + Eigen::VectorXd data(3); + data << 4.5, 5.1, 2.5; + + // Compute lpdf on this grid of points + auto evals = like->lpdf_grid(data); + auto like_copy = like->clone(); + auto evals_copy = like_copy->lpdf_grid(data); + + // Check if they coincides + ASSERT_EQ(evals, evals_copy); +} + +TEST(laplace_likelihood, eval_lpdf_unconstrained) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::UniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + double mean = 5; + double var = 1; + state_.set_mean(mean); + state_.set_var(var); + Eigen::VectorXd unconstrained_params(2); + unconstrained_params << mean, std::log(var); + clust_state_.mutable_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Add new datum to likelihood + Eigen::VectorXd data(3); + data << 4.5, 5.1, 2.5; + double lpdf = 0.0; + for (int i = 0; i < data.size(); ++i) { + like->add_datum(i, data.row(i)); + lpdf += like->lpdf(data.row(i)); + } + + double clus_lpdf = + like->cluster_lpdf_from_unconstrained(unconstrained_params); + ASSERT_DOUBLE_EQ(lpdf, clus_lpdf); + + unconstrained_params(0) = 3.0; + clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); + ASSERT_TRUE(std::abs(clus_lpdf - lpdf) > 1e-5); +} From e9f2315ae6e241b077e7d8d348195c29d0d967e1 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:04:00 +0100 Subject: [PATCH 166/317] Rename file --- ...lin_reg_state.h => uni_lin_reg_ls_state.h} | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) rename src/hierarchies/likelihoods/states/{uni_lin_reg_state.h => uni_lin_reg_ls_state.h} (65%) diff --git a/src/hierarchies/likelihoods/states/uni_lin_reg_state.h b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h similarity index 65% rename from src/hierarchies/likelihoods/states/uni_lin_reg_state.h rename to src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h index a97bfa1c0..fc7b015e9 100644 --- a/src/hierarchies/likelihoods/states/uni_lin_reg_state.h +++ b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h @@ -1,14 +1,13 @@ -#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_STATE_H_ -#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_STATE_H_ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_LS_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_LS_STATE_H_ -#include #include #include #include "algorithm_state.pb.h" +#include "src/utils/eigen_utils.h" #include "src/utils/proto_utils.h" -// TODO: CHECK VECTOR ASSIGNMENTS AND POSITIONING! namespace State { template @@ -37,7 +36,7 @@ T uni_lin_reg_log_det_jac(Eigen::Matrix constrained) { return out; } -class UniLinReg { +class UniLinRegLS { public: Eigen::VectorXd regression_coeffs; double var; @@ -55,18 +54,20 @@ class UniLinReg { var = temp(dim); } - // void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) - // { - // mean = state_.uni_ls_state().mean(); - // var = state_.uni_ls_state().var(); - // } + void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { + regression_coeffs = + bayesmix::to_eigen(state_.lin_reg_uni_ls_state().regression_coeffs()); + var = state_.lin_reg_uni_ls_state().var(); + } - // bayesmix::AlgorithmState::ClusterState get_as_proto() { - // bayesmix::AlgorithmState::ClusterState state; - // state.mutable_uni_ls_state()->set_mean(mean); - // state.mutable_uni_ls_state()->set_var(var); - // return state; - // } + bayesmix::AlgorithmState::ClusterState get_as_proto() { + bayesmix::LinRegUniLSState out; + bayesmix::to_proto(regression_coeffs, out.mutable_regression_coeffs()); + out.set_var(var); + bayesmix::AlgorithmState::ClusterState state; + state.mutable_lin_reg_uni_ls_state()->CopyFrom(out); + return state; + } double log_det_jac() { Eigen::VectorXd temp(regression_coeffs.size() + 1); @@ -77,4 +78,4 @@ class UniLinReg { } // namespace State -#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_STATE_H_ +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LIN_REG_LS_STATE_H_ From 156a613ffd93639c2d798774a1a26776e4911af4 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:04:46 +0100 Subject: [PATCH 167/317] Fixed get_like_lpdf() bug --- src/hierarchies/abstract_hierarchy.h | 2 +- src/hierarchies/base_hierarchy.h | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 8d7aa7485..fc9dc78a5 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -68,7 +68,7 @@ class AbstractHierarchy { // EVALUATION FUNCTIONS FOR SINGLE POINTS //! Public wrapper for `like_lpdf()` methods - double get_like_lpdf( + virtual double get_like_lpdf( const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { if (is_dependent()) { diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index d8c0b9a26..24c1025b1 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -97,12 +97,6 @@ class BaseHierarchy : public AbstractHierarchy { return out; } - // NOT SURE THIS IS CORRECT, MAYBE OVERRIDE GET_LIKE_LPDF? OR THIS IS EVEN - // UNNECESSARY - double like_lpdf(const Eigen::RowVectorXd &datum) const override { - return like->lpdf(datum); - } - //! Returns an independent, data-less copy of this object // std::shared_ptr deep_clone() const override { // auto out = std::make_shared(static_castlpdf(datum, covariate); + } + //! Evaluates the log-likelihood of data in a grid of points //! @param data Grid of points (by row) which are to be evaluated //! @param covariates (Optional) covariate vectors associated to data From 3ad975d1a22e87ae5b4e62f4b3f6cbc5dc11c944 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:05:02 +0100 Subject: [PATCH 168/317] Add lin_reg_uni_hierarchy --- src/hierarchies/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index 3ff8486f7..e8f5e427e 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -7,7 +7,7 @@ target_sources(bayesmix nnw_hierarchy.h # nnw_hierarchy.cc # conjugate_hierarchy.h - # lin_reg_uni_hierarchy.h + lin_reg_uni_hierarchy.h # lin_reg_uni_hierarchy.cc # fa_hierarchy.h # fa_hierarchy.cc From f3dd35c09674d2b73f5db8df4337a309c9fcecdb Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:05:40 +0100 Subject: [PATCH 169/317] Improved exeption handling --- src/hierarchies/likelihoods/abstract_likelihood.h | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index ca7ef4403..936d93b18 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -19,7 +19,7 @@ class AbstractLikelihood { double lpdf( const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { - if (is_dependent()) { + if (is_dependent() and covariate.size() != 0) { return compute_lpdf(datum, covariate); } else { return compute_lpdf(datum); @@ -96,9 +96,10 @@ class AbstractLikelihood { virtual double compute_lpdf(const Eigen::RowVectorXd &datum) const { if (is_dependent()) { throw std::runtime_error( - "Cannot call this function from a dependent likelihood"); + "Cannot call compute_lpdf() from a dependent likelihood"); } else { - throw std::runtime_error("Not implemented"); + throw std::runtime_error( + "compute_lpdf() not implemented for this likelihood"); } } @@ -106,9 +107,10 @@ class AbstractLikelihood { const Eigen::RowVectorXd &covariate) const { if (!is_dependent()) { throw std::runtime_error( - "Cannot call this function from a non-dependent likelihood"); + "Cannot call compute_lpdf() from a non-dependent likelihood"); } else { - throw std::runtime_error("Not implemented"); + throw std::runtime_error( + "compute_lpdf() not implemented for this likelihood"); } } From dc3233ea0a82e9ee565334cb5ae5e6567edbbaca Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:06:17 +0100 Subject: [PATCH 170/317] Add MultiNormalInverseGamma prior model --- src/hierarchies/priors/mnig_prior_model.cc | 90 ++++++++++++++++++++++ src/hierarchies/priors/mnig_prior_model.h | 44 +++++++++++ 2 files changed, 134 insertions(+) create mode 100644 src/hierarchies/priors/mnig_prior_model.cc create mode 100644 src/hierarchies/priors/mnig_prior_model.h diff --git a/src/hierarchies/priors/mnig_prior_model.cc b/src/hierarchies/priors/mnig_prior_model.cc new file mode 100644 index 000000000..34041a30a --- /dev/null +++ b/src/hierarchies/priors/mnig_prior_model.cc @@ -0,0 +1,90 @@ +#include "mnig_prior_model.h" + +double MNIGPriorModel::lpdf(const google::protobuf::Message &state_) { + auto &state = downcast_state(state_).lin_reg_uni_ls_state(); + Eigen::VectorXd regression_coeffs = + bayesmix::to_eigen(state.regression_coeffs()); + double target = stan::math::multi_normal_prec_lpdf( + regression_coeffs, hypers->mean, hypers->var_scaling / state.var()); + target += + stan::math::inv_gamma_lpdf(state.var(), hypers->shape, hypers->scale); + return target; +} + +std::shared_ptr MNIGPriorModel::sample( + bool use_post_hypers) { + auto &rng = bayesmix::Rng::Instance().get(); + Hyperparams::MNIG params = use_post_hypers ? post_hypers : *hypers; + + double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); + Eigen::VectorXd regression_coeffs = stan::math::multi_normal_prec_rng( + params.mean, params.var_scaling / var, rng); + + bayesmix::AlgorithmState::ClusterState state; + // bayesmix::Vector regression_coeffs_proto; + bayesmix::to_proto( + regression_coeffs, + state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()); + // state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()->CopyFrom(regression_coeffs_proto); + state.mutable_lin_reg_uni_ls_state()->set_var(var); + + return std::make_shared(state); +} + +void MNIGPriorModel::update_hypers( + const std::vector &states) { + if (prior->has_fixed_values()) { + return; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void MNIGPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).lin_reg_uni_state(); + hypers->mean = bayesmix::to_eigen(hyperscast.mean()); + hypers->var_scaling = bayesmix::to_eigen(hyperscast.var_scaling()); + hypers->scale = hyperscast.scale(); + hypers->shape = hyperscast.shape(); +} + +std::shared_ptr +MNIGPriorModel::get_hypers_proto() const { + bayesmix::MultiNormalIGDistribution hypers_; + bayesmix::to_proto(hypers->mean, hypers_.mutable_mean()); + bayesmix::to_proto(hypers->var_scaling, hypers_.mutable_var_scaling()); + hypers_.set_shape(hypers->shape); + hypers_.set_scale(hypers->scale); + + auto out = std::make_shared(); + out->mutable_lin_reg_uni_state()->CopyFrom(hypers_); + return out; +} + +void MNIGPriorModel::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); + dim = hypers->mean.size(); + hypers->var_scaling = + bayesmix::to_eigen(prior->fixed_values().var_scaling()); + hypers->var_scaling_inv = stan::math::inverse_spd(hypers->var_scaling); + hypers->shape = prior->fixed_values().shape(); + hypers->scale = prior->fixed_values().scale(); + // Check validity + if (dim != hypers->var_scaling.rows()) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + bayesmix::check_spd(hypers->var_scaling); + if (hypers->shape <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + if (hypers->scale <= 0) { + throw std::invalid_argument("scale parameter must be > 0"); + } + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h new file mode 100644 index 000000000..6502849d5 --- /dev/null +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -0,0 +1,44 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_MNIG_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_MNIG_PRIOR_MODEL_H_ + +// #include + +#include +#include +#include + +// #include "algorithm_state.pb.h" +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +class MNIGPriorModel : public BasePriorModel { + public: + MNIGPriorModel() = default; + ~MNIGPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + std::shared_ptr sample( + bool use_post_hypers) override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + unsigned int get_dim() const { return dim; }; + + protected: + std::shared_ptr get_hypers_proto() + const override; + + void initialize_hypers() override; + + unsigned int dim; +}; + +#endif // BAYESMIX_HIERARCHIES_PRIORS_MNIG_PRIOR_MODEL_H_ From 28419a2231ad52a2f169b1d689091204ccaf130b Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:06:47 +0100 Subject: [PATCH 171/317] Add MultiNormalInverseGamma updater --- src/hierarchies/updaters/mnig_updater.cc | 39 ++++++++++++++++++++++++ src/hierarchies/updaters/mnig_updater.h | 18 +++++++++++ 2 files changed, 57 insertions(+) create mode 100644 src/hierarchies/updaters/mnig_updater.cc create mode 100644 src/hierarchies/updaters/mnig_updater.h diff --git a/src/hierarchies/updaters/mnig_updater.cc b/src/hierarchies/updaters/mnig_updater.cc new file mode 100644 index 000000000..2d1280094 --- /dev/null +++ b/src/hierarchies/updaters/mnig_updater.cc @@ -0,0 +1,39 @@ +#include "mnig_updater.h" + +void MNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Getting required quantities from likelihood and prior + int card = likecast.get_card(); + unsigned int dim = likecast.get_dim(); + double data_sum_squares = likecast.get_data_sum_squares(); + Eigen::MatrixXd covar_sum_squares = likecast.get_covar_sum_squares(); + Eigen::MatrixXd mixed_prod = likecast.get_mixed_prod(); + auto hypers = priorcast.get_hypers(); + + // No update possible + if (card == 0) { + priorcast.set_posterior_hypers(hypers); + return; + } + + // Compute posterior hyperparameters + Hyperparams::MNIG post_params; + post_params.var_scaling = covar_sum_squares + hypers.var_scaling; + auto llt = post_params.var_scaling.llt(); + post_params.var_scaling_inv = llt.solve(Eigen::MatrixXd::Identity(dim, dim)); + post_params.mean = llt.solve(mixed_prod + hypers.var_scaling * hypers.mean); + post_params.shape = hypers.shape + 0.5 * card; + post_params.scale = + hypers.scale + + 0.5 * (data_sum_squares + + hypers.mean.transpose() * hypers.var_scaling * hypers.mean - + post_params.mean.transpose() * post_params.var_scaling * + post_params.mean); + + priorcast.set_posterior_hypers(post_params); + return; +}; diff --git a/src/hierarchies/updaters/mnig_updater.h b/src/hierarchies/updaters/mnig_updater.h new file mode 100644 index 000000000..6a58174bd --- /dev/null +++ b/src/hierarchies/updaters/mnig_updater.h @@ -0,0 +1,18 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_MNIG_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_MNIG_UPDATER_H_ + +#include "conjugate_updater.h" +#include "src/hierarchies/likelihoods/uni_lin_reg_likelihood.h" +#include "src/hierarchies/priors/mnig_prior_model.h" + +class MNIGUpdater + : public ConjugateUpdater { + public: + MNIGUpdater() = default; + ~MNIGUpdater() = default; + + void compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) override; +}; + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_MNIG_UPDATER_H_ From b8da1d57735d69d42980f9fccbf6851c410d4cd8 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:07:09 +0100 Subject: [PATCH 172/317] Add mnig_updater --- src/hierarchies/updaters/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index 85d835a1f..1a3dd4d82 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -12,4 +12,6 @@ target_sources(bayesmix PUBLIC nnxig_updater.cc nnw_updater.h nnw_updater.cc + mnig_updater.h + mnig_updater.cc ) From 090030c2c7a1cf29f89a3d28fa72a187bfce90af Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:07:21 +0100 Subject: [PATCH 173/317] Add mnig_prior_model --- src/hierarchies/priors/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt index 5bb705f36..5b218e6c0 100644 --- a/src/hierarchies/priors/CMakeLists.txt +++ b/src/hierarchies/priors/CMakeLists.txt @@ -8,4 +8,6 @@ target_sources(bayesmix PUBLIC nxig_prior_model.cc nw_prior_model.h nw_prior_model.cc + mnig_prior_model.h + mnig_prior_model.cc ) From 0758d31e1a899dc5916e74a2ccdfbb263ec777c6 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:07:49 +0100 Subject: [PATCH 174/317] Clean useless includes --- src/hierarchies/likelihoods/states/multi_ls_state.h | 1 - src/hierarchies/likelihoods/states/uni_ls_state.h | 1 - 2 files changed, 2 deletions(-) diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h index 406d008fe..b83970847 100644 --- a/src/hierarchies/likelihoods/states/multi_ls_state.h +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -1,7 +1,6 @@ #ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_MULTI_LS_STATE_H_ #define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_MULTI_LS_STATE_H_ -// #include #include #include diff --git a/src/hierarchies/likelihoods/states/uni_ls_state.h b/src/hierarchies/likelihoods/states/uni_ls_state.h index c7b242ba5..a347c54cd 100644 --- a/src/hierarchies/likelihoods/states/uni_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_ls_state.h @@ -1,7 +1,6 @@ #ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LS_STATE_H_ #define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LS_STATE_H_ -#include #include #include "algorithm_state.pb.h" From b391f9c637fcff0f2380c348cd026726df1c1884 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:08:17 +0100 Subject: [PATCH 175/317] Add UniLinReg likelihood --- .../likelihoods/uni_lin_reg_likelihood.cc | 50 +++++++++++++++++ .../likelihoods/uni_lin_reg_likelihood.h | 55 +++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc create mode 100644 src/hierarchies/likelihoods/uni_lin_reg_likelihood.h diff --git a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc new file mode 100644 index 000000000..a4ddfaab9 --- /dev/null +++ b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc @@ -0,0 +1,50 @@ +#include "uni_lin_reg_likelihood.h" + +#include "src/utils/eigen_utils.h" + +void UniLinRegLikelihood::set_state_from_proto( + const google::protobuf::Message &state_, bool update_card) { + auto &statecast = downcast_state(state_); + state.regression_coeffs = + bayesmix::to_eigen(statecast.lin_reg_uni_ls_state().regression_coeffs()); + state.var = statecast.lin_reg_uni_ls_state().var(); + if (update_card) set_card(statecast.cardinality()); +} + +void UniLinRegLikelihood::clear_summary_statistics() { + mixed_prod = Eigen::VectorXd::Zero(dim); + data_sum_squares = 0.0; + covar_sum_squares = Eigen::MatrixXd::Zero(dim, dim); +} + +std::shared_ptr +UniLinRegLikelihood::get_state_proto() const { + bayesmix::LinRegUniLSState state_; + bayesmix::to_proto(state.regression_coeffs, + state_.mutable_regression_coeffs()); + state_.set_var(state.var); + auto out = std::make_shared(); + out->mutable_lin_reg_uni_ls_state()->CopyFrom(state_); + return out; +} + +double UniLinRegLikelihood::compute_lpdf( + const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const { + return stan::math::normal_lpdf( + datum(0), state.regression_coeffs.dot(covariate), sqrt(state.var)); +} + +void UniLinRegLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, + bool add) { + if (add) { + data_sum_squares += datum(0) * datum(0); + covar_sum_squares += covariate.transpose() * covariate; + mixed_prod += datum(0) * covariate.transpose(); + } else { + data_sum_squares -= datum(0) * datum(0); + covar_sum_squares -= covariate.transpose() * covariate; + mixed_prod -= datum(0) * covariate.transpose(); + } +} diff --git a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h new file mode 100644 index 000000000..25f559e66 --- /dev/null +++ b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h @@ -0,0 +1,55 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_LIN_REG_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_LIN_REG_LIKELIHOOD_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states/includes.h" + +class UniLinRegLikelihood + : public BaseLikelihood { + public: + UniLinRegLikelihood() = default; + ~UniLinRegLikelihood() = default; + bool is_multivariate() const override { return false; }; + bool is_dependent() const override { return true; }; + void set_state_from_proto(const google::protobuf::Message &state_, + bool update_card = true) override; + void clear_summary_statistics() override; + + // Getters and Setters + unsigned int get_dim() const { return dim; }; + void set_dim(unsigned int dim_) { + dim = dim_; + clear_summary_statistics(); + }; + double get_data_sum_squares() const { return data_sum_squares; }; + Eigen::MatrixXd get_covar_sum_squares() const { return covar_sum_squares; }; + Eigen::VectorXd get_mixed_prod() const { return mixed_prod; }; + + std::shared_ptr get_state_proto() + const override; + + protected: + double compute_lpdf(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const override; + void update_sum_stats(const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate, + bool add) override; + + //! Dimension of the coefficients vector + unsigned int dim; + //! Represents pieces of y^t y + double data_sum_squares; + //! Represents pieces of X^T X + Eigen::MatrixXd covar_sum_squares; + //! Represents pieces of X^t y + Eigen::VectorXd mixed_prod; +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_UNI_LIN_REG_LIKELIHOOD_H_ From 1e95d408c0f9d9782534e11ec38c025cc50fd4c4 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:08:45 +0100 Subject: [PATCH 176/317] Add tests for uni_lin_reg_likelihood --- test/likelihoods.cc | 79 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/test/likelihoods.cc b/test/likelihoods.cc index bea95a0b1..bd52f7b26 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -8,6 +8,7 @@ #include "ls_state.pb.h" #include "src/hierarchies/likelihoods/laplace_likelihood.h" #include "src/hierarchies/likelihoods/multi_norm_likelihood.h" +#include "src/hierarchies/likelihoods/uni_lin_reg_likelihood.h" #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" #include "src/utils/proto_utils.h" #include "src/utils/rng.h" @@ -210,6 +211,84 @@ TEST(multi_norm_likelihood, eval_lpdf) { ASSERT_EQ(evals, evals_copy); } +TEST(uni_lin_reg_likelihood, set_get_state) { + // Instance + auto like = std::make_shared(); + + // Prepare buffers + bayesmix::LinRegUniLSState state_; + bayesmix::AlgorithmState::ClusterState set_state_; + bayesmix::AlgorithmState::ClusterState got_state_; + + // Prepare state + Eigen::Vector3d reg_coeffs; + reg_coeffs << 2.25, 0.22, -7.1; + bayesmix::Vector reg_coeffs_proto; + bayesmix::to_proto(reg_coeffs, ®_coeffs_proto); + state_.mutable_regression_coeffs()->CopyFrom(reg_coeffs_proto); + state_.set_var(1.02); + set_state_.mutable_lin_reg_uni_ls_state()->CopyFrom(state_); + + // Set and get the state + like->set_state_from_proto(set_state_); + like->write_state_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(uni_lin_reg_likelihood, add_remove_data) { + // Instance + auto like = std::make_shared(); + + // Add new datum to likelihood + Eigen::VectorXd datum(1); + datum << 5.0; + like->add_datum(0, datum); + + // Check if cardinality is augmented + ASSERT_EQ(like->get_card(), 1); + + // Remove datum from likelihood + like->remove_datum(0, datum); + + // Check if cardinality is reduced + ASSERT_EQ(like->get_card(), 0); +} + +TEST(uni_lin_reg_likelihood, eval_lpdf) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::LinRegUniLSState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + Eigen::Vector3d reg_coeffs; + reg_coeffs << 2.25, 0.22, -7.1; + bayesmix::Vector reg_coeffs_proto; + bayesmix::to_proto(reg_coeffs, ®_coeffs_proto); + state_.mutable_regression_coeffs()->CopyFrom(reg_coeffs_proto); + state_.set_var(1.02); + clust_state_.mutable_lin_reg_uni_ls_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Generate data + Eigen::Vector3d data; + data << 4.5, 5.1, 2.5; + + // Generate random covariate matrix + Eigen::MatrixXd cov = + Eigen::MatrixXd::Random(data.size(), reg_coeffs.size()); + + // Compute lpdf on this grid of points + auto evals = like->lpdf_grid(data, cov); + auto like_copy = like->clone(); + auto evals_copy = like_copy->lpdf_grid(data, cov); + + // Check if they coincides + ASSERT_EQ(evals, evals_copy); +} + TEST(laplace_likelihood, set_get_state) { // Instance auto like = std::make_shared(); From 066699ec41ecc9c483035a40a1b613e19c5d7a8c Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:09:02 +0100 Subject: [PATCH 177/317] Uncomment old tests --- test/lpdf.cc | 113 ++++++++++++++++++++++++++------------------------- 1 file changed, 57 insertions(+), 56 deletions(-) diff --git a/test/lpdf.cc b/test/lpdf.cc index 90512e20c..0f41f69ba 100644 --- a/test/lpdf.cc +++ b/test/lpdf.cc @@ -5,7 +5,7 @@ #include #include "algorithm_state.pb.h" -// #include "src/hierarchies/lin_reg_uni_hierarchy.h" +#include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" // #include "src/hierarchies/nnw_hierarchy.h" #include "src/utils/proto_utils.h" @@ -137,62 +137,63 @@ TEST(lpdf, nnig) { // ASSERT_DOUBLE_EQ(marg, marg_murphy); // } -// TEST(lpdf, lin_reg_uni) { -// // Create hierarchy objects -// LinRegUniHierarchy hier; -// bayesmix::LinRegUniPrior prior; -// int dim = 3; - -// // Generate data -// Eigen::VectorXd datum(1); -// datum << 1.5; -// Eigen::VectorXd cov = Eigen::VectorXd::Random(dim); - -// // Create parameters, both Eigen and proto -// Eigen::VectorXd mu0(dim); -// for (int i = 0; i < dim; i++) { -// mu0(i) = 2 * i; -// } -// bayesmix::Vector mu0_proto; -// bayesmix::to_proto(mu0, &mu0_proto); -// auto Lambda0 = Eigen::MatrixXd::Identity(dim, dim); -// bayesmix::Matrix Lambda0_proto; -// bayesmix::to_proto(Lambda0, &Lambda0_proto); -// double alpha0 = 2.0; -// double beta0 = 2.0; -// // Set parameters -// *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; -// *prior.mutable_fixed_values()->mutable_var_scaling() = Lambda0_proto; -// prior.mutable_fixed_values()->set_shape(alpha0); -// prior.mutable_fixed_values()->set_scale(beta0); -// // Initialize hierarchy -// hier.get_mutable_prior()->CopyFrom(prior); -// hier.initialize(); +TEST(lpdf, lin_reg_uni) { + // Create hierarchy objects + LinRegUniHierarchy hier; + bayesmix::LinRegUniPrior prior; + int dim = 3; -// // Compute prior parameters -// Eigen::VectorXd mean = mu0; -// double var = beta0 / (alpha0 + 1); + // Generate data + Eigen::VectorXd datum(1); + datum << 1.5; + Eigen::VectorXd cov = Eigen::VectorXd::Random(dim); + + // Create parameters, both Eigen and proto + Eigen::VectorXd mu0(dim); + for (int i = 0; i < dim; i++) { + mu0(i) = 2 * i; + } + bayesmix::Vector mu0_proto; + bayesmix::to_proto(mu0, &mu0_proto); + auto Lambda0 = Eigen::MatrixXd::Identity(dim, dim); + bayesmix::Matrix Lambda0_proto; + bayesmix::to_proto(Lambda0, &Lambda0_proto); + double alpha0 = 2.0; + double beta0 = 2.0; + // Set parameters + *prior.mutable_fixed_values()->mutable_mean() = mu0_proto; + *prior.mutable_fixed_values()->mutable_var_scaling() = Lambda0_proto; + prior.mutable_fixed_values()->set_shape(alpha0); + prior.mutable_fixed_values()->set_scale(beta0); + // Initialize hierarchy + hier.get_mutable_prior()->CopyFrom(prior); + hier.initialize(); -// // Compute posterior parameters -// Eigen::MatrixXd Lambda_n = Lambda0 + cov * cov.transpose(); -// Eigen::VectorXd mu_n = -// stan::math::inverse_spd(Lambda_n) * (datum(0) * cov + Lambda0 * mu0); -// double alpha_n = alpha0 + 0.5; -// double beta_n = -// beta0 + 0.5 * (datum(0) * datum(0) + mu0.transpose() * Lambda0 * mu0 - -// mu_n.transpose() * Lambda_n * mu_n); -// // Compute pieces -// double prior1 = stan::math::inv_gamma_lpdf(var, alpha0, beta0); -// double prior2 = stan::math::multi_normal_prec_lpdf(mean, mu0, Lambda0 / -// var); double pr = prior1 + prior2; double like = hier.get_like_lpdf(datum, -// cov); double post1 = stan::math::inv_gamma_lpdf(var, alpha_n, beta_n); -// double post2 = -// stan::math::multi_normal_prec_lpdf(mean, mu_n, Lambda_n / var); -// double post = post1 + post2; + // Compute prior parameters + Eigen::VectorXd mean = mu0; + double var = beta0 / (alpha0 + 1); -// // Bayes: logmarg(x) = logprior(phi) + loglik(x|phi) - logpost(phi|x) -// double sum = pr + like - post; -// double marg = hier.prior_pred_lpdf(datum, cov); + // Compute posterior parameters + Eigen::MatrixXd Lambda_n = Lambda0 + cov * cov.transpose(); + Eigen::VectorXd mu_n = + stan::math::inverse_spd(Lambda_n) * (datum(0) * cov + Lambda0 * mu0); + double alpha_n = alpha0 + 0.5; + double beta_n = + beta0 + 0.5 * (datum(0) * datum(0) + mu0.transpose() * Lambda0 * mu0 - + mu_n.transpose() * Lambda_n * mu_n); + // Compute pieces + double prior1 = stan::math::inv_gamma_lpdf(var, alpha0, beta0); + double prior2 = stan::math::multi_normal_prec_lpdf(mean, mu0, Lambda0 / var); + double pr = prior1 + prior2; + double like = hier.get_like_lpdf(datum, cov); + double post1 = stan::math::inv_gamma_lpdf(var, alpha_n, beta_n); + double post2 = + stan::math::multi_normal_prec_lpdf(mean, mu_n, Lambda_n / var); + double post = post1 + post2; -// ASSERT_FLOAT_EQ(sum, marg); -// } + // Bayes: logmarg(x) = logprior(phi) + loglik(x|phi) - logpost(phi|x) + double sum = pr + like - post; + double marg = hier.prior_pred_lpdf(datum, cov); + + ASSERT_FLOAT_EQ(sum, marg); +} From c1ab566d71fdb10cbc8f78f9a7af78060743be1e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:09:28 +0100 Subject: [PATCH 178/317] Add tests for mnig_prior_model --- test/prior_models.cc | 114 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 107 insertions(+), 7 deletions(-) diff --git a/test/prior_models.cc b/test/prior_models.cc index 1201e9c88..2374ae73d 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -6,10 +6,10 @@ #include "algorithm_state.pb.h" #include "hierarchy_prior.pb.h" - +#include "src/hierarchies/priors/mnig_prior_model.h" #include "src/hierarchies/priors/nig_prior_model.h" -#include "src/hierarchies/priors/nxig_prior_model.h" #include "src/hierarchies/priors/nw_prior_model.h" +#include "src/hierarchies/priors/nxig_prior_model.h" #include "src/utils/proto_utils.h" TEST(nig_prior_model, set_get_hypers) { @@ -321,8 +321,14 @@ TEST(nw_prior_model, normal_mean_prior) { bayesmix::to_proto(mu00, &mu00_proto); bayesmix::to_proto(Sigma00, &Sigma00_proto); bayesmix::to_proto(Eigen::Matrix2d::Identity(), &scale_proto); - prior.mutable_normal_mean_prior()->mutable_mean_prior()->mutable_mean()->CopyFrom(mu00_proto); - prior.mutable_normal_mean_prior()->mutable_mean_prior()->mutable_var()->CopyFrom(Sigma00_proto); + prior.mutable_normal_mean_prior() + ->mutable_mean_prior() + ->mutable_mean() + ->CopyFrom(mu00_proto); + prior.mutable_normal_mean_prior() + ->mutable_mean_prior() + ->mutable_var() + ->CopyFrom(Sigma00_proto); prior.mutable_normal_mean_prior()->set_var_scaling(0.1); prior.mutable_normal_mean_prior()->set_deg_free(4); prior.mutable_normal_mean_prior()->mutable_scale()->CopyFrom(scale_proto); @@ -331,10 +337,12 @@ TEST(nw_prior_model, normal_mean_prior) { std::vector states(4); for (int i = 0; i < states.size(); i++) { Eigen::Vector2d mean = (9.0 + i) * Eigen::Vector2d::Ones(); - bayesmix::Vector tmp; bayesmix::to_proto(mean, &tmp); + bayesmix::Vector tmp; + bayesmix::to_proto(mean, &tmp); states[i].mutable_multi_ls_state()->mutable_mean()->CopyFrom(tmp); states[i].mutable_multi_ls_state()->mutable_prec()->CopyFrom(scale_proto); - states[i].mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom(scale_proto); + states[i].mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom( + scale_proto); } // Initialize prior model @@ -362,7 +370,7 @@ TEST(nw_prior_model, sample) { bayesmix::AlgorithmState::HierarchyHypers hypers_proto; bayesmix::Vector mean; bayesmix::Matrix scale; - bayesmix::to_proto(Eigen::Vector2d({5.2,5.2}), &mean); + bayesmix::to_proto(Eigen::Vector2d({5.2, 5.2}), &mean); bayesmix::to_proto(Eigen::Matrix2d::Identity(), &scale); hypers_proto.mutable_nnw_state()->mutable_mean()->CopyFrom(mean); hypers_proto.mutable_nnw_state()->set_var_scaling(0.1); @@ -377,3 +385,95 @@ TEST(nw_prior_model, sample) { // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); } + +TEST(mnig_prior_model, set_get_hypers) { + // Instance + auto prior = std::make_shared(); + + // Prepare buffers + bayesmix::MultiNormalIGDistribution hypers_; + bayesmix::AlgorithmState::HierarchyHypers set_state_; + bayesmix::AlgorithmState::HierarchyHypers got_state_; + + // Prepare hypers + Eigen::Vector2d mean({2.0, 2.0}); + bayesmix::to_proto(mean, hypers_.mutable_mean()); + Eigen::Matrix2d var_scaling = Eigen::Matrix2d::Identity(); + bayesmix::to_proto(var_scaling, hypers_.mutable_var_scaling()); + hypers_.set_shape(4.0); + hypers_.set_scale(3.0); + set_state_.mutable_lin_reg_uni_state()->CopyFrom(hypers_); + + // Set and get hypers + prior->set_hypers_from_proto(set_state_); + prior->write_hypers_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(mnig_prior_model, fixed_values_prior) { + // Prepare buffers + bayesmix::LinRegUniPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + std::vector> prior_models; + std::vector states; + + // Set fixed value prior + Eigen::Vector2d mean({2.0, 2.0}); + bayesmix::to_proto(mean, prior.mutable_fixed_values()->mutable_mean()); + Eigen::Matrix2d var_scaling = Eigen::Matrix2d::Identity(); + bayesmix::to_proto(var_scaling, + prior.mutable_fixed_values()->mutable_var_scaling()); + prior.mutable_fixed_values()->set_shape(4.0); + prior.mutable_fixed_values()->set_scale(3.0); + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Check equality before update + prior_models.push_back(prior_model); + for (size_t i = 1; i < 4; i++) { + prior_models.push_back(prior_model->clone()); + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.lin_reg_uni_state().DebugString()); + } + + // Check equality after update + prior_models[0]->update_hypers(states); + prior_models[0]->write_hypers_to_proto(&prior_out); + for (size_t i = 1; i < 4; i++) { + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.lin_reg_uni_state().DebugString()); + } +} + +TEST(mnig_prior_model, sample) { + // Instance + auto prior = std::make_shared(); + bool use_post_hypers = true; + + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + Eigen::Vector2d mean({5.0, 5.0}); + bayesmix::to_proto(mean, + hypers_proto.mutable_lin_reg_uni_state()->mutable_mean()); + Eigen::Matrix2d var_scaling = Eigen::Matrix2d::Identity(); + bayesmix::to_proto( + var_scaling, + hypers_proto.mutable_lin_reg_uni_state()->mutable_var_scaling()); + hypers_proto.mutable_lin_reg_uni_state()->set_shape(4.0); + hypers_proto.mutable_lin_reg_uni_state()->set_scale(3.0); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + auto state1 = prior->sample(!use_post_hypers); + auto state2 = prior->sample(!use_post_hypers); + + // Check if they coincides + ASSERT_TRUE(state1->DebugString() != state2->DebugString()); +} From 3655123f83802c44ea6bd738564e226eedba2658 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:09:40 +0100 Subject: [PATCH 179/317] Change file name --- src/hierarchies/likelihoods/states/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/likelihoods/states/CMakeLists.txt b/src/hierarchies/likelihoods/states/CMakeLists.txt index fa3d87c32..3c5a67426 100644 --- a/src/hierarchies/likelihoods/states/CMakeLists.txt +++ b/src/hierarchies/likelihoods/states/CMakeLists.txt @@ -1,6 +1,6 @@ target_sources(bayesmix PUBLIC uni_ls_state.h multi_ls_state.h - uni_lin_reg_state.h + uni_lin_reg_ls_state.h includes.h ) From 78b284aecd15bb54c67024cbbe84e4cd15921a71 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:09:51 +0100 Subject: [PATCH 180/317] Improved code --- src/hierarchies/likelihoods/multi_norm_likelihood.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index f936a4c5a..0ffc84cf3 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -25,8 +25,7 @@ class MultiNormLikelihood void set_dim(unsigned int dim_) { dim = dim_; - data_sum = Eigen::VectorXd::Zero(dim); - data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); + clear_summary_statistics(); }; unsigned int get_dim() const { return dim; }; Eigen::VectorXd get_data_sum() const { return data_sum; }; From 8fe7913f5fb3bf5fa1f223d414996c1d7ec35a4e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:10:09 +0100 Subject: [PATCH 181/317] Add LinRegUniHierarchy to load --- src/hierarchies/load_hierarchies.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/load_hierarchies.h b/src/hierarchies/load_hierarchies.h index 602487b80..eb87cac63 100644 --- a/src/hierarchies/load_hierarchies.h +++ b/src/hierarchies/load_hierarchies.h @@ -8,7 +8,7 @@ // #include "fa_hierarchy.h" #include "hierarchy_id.pb.h" #include "lapnig_hierarchy.h" -// #include "lin_reg_uni_hierarchy.h" +#include "lin_reg_uni_hierarchy.h" #include "nnig_hierarchy.h" #include "nnw_hierarchy.h" #include "nnxig_hierarchy.h" @@ -34,9 +34,9 @@ __attribute__((constructor)) static void load_hierarchies() { Builder NNWbuilder = []() { return std::make_shared(); }; - // Builder LinRegUnibuilder = []() { - // return std::make_shared(); - // }; + Builder LinRegUnibuilder = []() { + return std::make_shared(); + }; // Builder FAbuilder = []() { // return std::make_shared(); // }; @@ -47,7 +47,7 @@ __attribute__((constructor)) static void load_hierarchies() { factory.add_builder(NNIGHierarchy().get_id(), NNIGbuilder); factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); factory.add_builder(NNWHierarchy().get_id(), NNWbuilder); - // factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); + factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); // factory.add_builder(FAHierarchy().get_id(), FAbuilder); factory.add_builder(LapNIGHierarchy().get_id(), LapNIGbuilder); } From 93cc417cae0675741f3a0d2e66ef3df9b3bc8313 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:10:36 +0100 Subject: [PATCH 182/317] Re-factor LinRegUniHierarchy --- src/hierarchies/lin_reg_uni_hierarchy.h | 57 +++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 src/hierarchies/lin_reg_uni_hierarchy.h diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h new file mode 100644 index 000000000..a76d44f58 --- /dev/null +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -0,0 +1,57 @@ +#ifndef BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ + +// #include + +// #include +// #include +// #include + +// #include "algorithm_state.pb.h" +// #include "conjugate_hierarchy.h" +#include "hierarchy_id.pb.h" +// #include "hierarchy_prior.pb.h" + +#include "base_hierarchy.h" +#include "likelihoods/uni_lin_reg_likelihood.h" +#include "priors/mnig_prior_model.h" +#include "updaters/mnig_updater.h" + +class LinRegUniHierarchy + : public BaseHierarchy { + public: + ~LinRegUniHierarchy() = default; + + using BaseHierarchy::BaseHierarchy; + + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::LinRegUni; + } + + void set_default_updater() { updater = std::make_shared(); } + + void initialize_state() override { + // Initialize likelihood dimension to prior one + like->set_dim(prior->get_dim()); + // Get hypers + auto hypers = prior->get_hypers(); + // Initialize likelihood state + State::UniLinRegLS state; + state.regression_coeffs = hypers.mean; + state.var = hypers.scale / (hypers.shape + 1); + like->set_state(state); + }; + + double marg_lpdf(const HyperParams ¶ms, const Eigen::RowVectorXd &datum, + const Eigen::RowVectorXd &covariate) const override { + double sig_n = sqrt( + (1 + (covariate * params.var_scaling_inv * covariate.transpose())(0)) * + params.scale / params.shape); + return stan::math::student_t_lpdf(datum(0), 2 * params.shape, + covariate.dot(params.mean), sig_n); + } +}; + +#endif // BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ From d71821efcefab249abef9a45adb537fbb85e7c4a Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:10:57 +0100 Subject: [PATCH 183/317] Change file name --- src/hierarchies/likelihoods/states/includes.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/likelihoods/states/includes.h b/src/hierarchies/likelihoods/states/includes.h index d0c921f6f..b1282fb6e 100644 --- a/src/hierarchies/likelihoods/states/includes.h +++ b/src/hierarchies/likelihoods/states/includes.h @@ -2,7 +2,7 @@ #define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ #include "multi_ls_state.h" -#include "uni_lin_reg_state.h" +#include "uni_lin_reg_ls_state.h" #include "uni_ls_state.h" #endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ From 24b96bdc654857332b30d6c7deb0f96b99ab098a Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:11:16 +0100 Subject: [PATCH 184/317] Add uni_lin_reg_likelihood --- src/hierarchies/likelihoods/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index b9a6a25fa..42cc97da3 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -5,6 +5,8 @@ target_sources(bayesmix PUBLIC uni_norm_likelihood.cc multi_norm_likelihood.h multi_norm_likelihood.cc + uni_lin_reg_likelihood.h + uni_lin_reg_likelihood.cc laplace_likelihood.h laplace_likelihood.cc ) From 94ac41362f696abb18d15785ac87dc8fdaf17a8d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:11:35 +0100 Subject: [PATCH 185/317] Uncomment tests for LinRegUniHierarchy --- test/hierarchies.cc | 164 ++++++++++++++++++++++---------------------- 1 file changed, 82 insertions(+), 82 deletions(-) diff --git a/test/hierarchies.cc b/test/hierarchies.cc index 512c10847..027457a69 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -5,7 +5,7 @@ #include "algorithm_state.pb.h" #include "ls_state.pb.h" -// #include "src/hierarchies/lin_reg_uni_hierarchy.h" +#include "src/hierarchies/lin_reg_uni_hierarchy.h" // #include "src/hierarchies/fa_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" #include "src/hierarchies/nnw_hierarchy.h" @@ -163,93 +163,93 @@ TEST(nnw_hierarchy, no_unconstrained_lpdf) { EXPECT_ANY_THROW(hier->get_prior()->lpdf_from_unconstrained(state_uc)); } -// TEST(lin_reg_uni_hierarchy, state_read_write) { -// Eigen::Vector2d beta; -// beta << 2, -1; -// double sigma2 = 9; +TEST(lin_reg_uni_hierarchy, state_read_write) { + Eigen::Vector2d beta; + beta << 2, -1; + double sigma2 = 9; -// bayesmix::LinRegUniLSState ls; -// bayesmix::to_proto(beta, ls.mutable_regression_coeffs()); -// ls.set_var(sigma2); + bayesmix::LinRegUniLSState ls; + bayesmix::to_proto(beta, ls.mutable_regression_coeffs()); + ls.set_var(sigma2); -// bayesmix::AlgorithmState::ClusterState state; -// state.mutable_lin_reg_uni_ls_state()->CopyFrom(ls); + bayesmix::AlgorithmState::ClusterState state; + state.mutable_lin_reg_uni_ls_state()->CopyFrom(ls); -// LinRegUniHierarchy hier; -// hier.set_state_from_proto(state); + LinRegUniHierarchy hier; + hier.set_state_from_proto(state); -// ASSERT_EQ(hier.get_state().regression_coeffs, beta); -// ASSERT_EQ(hier.get_state().var, sigma2); + ASSERT_EQ(hier.get_state().regression_coeffs, beta); + ASSERT_EQ(hier.get_state().var, sigma2); -// bayesmix::AlgorithmState outt; -// bayesmix::AlgorithmState::ClusterState* out = outt.add_cluster_states(); -// hier.write_state_to_proto(out); -// ASSERT_EQ(beta, bayesmix::to_eigen( -// out->lin_reg_uni_ls_state().regression_coeffs())); -// ASSERT_EQ(sigma2, out->lin_reg_uni_ls_state().var()); -// } + bayesmix::AlgorithmState outt; + bayesmix::AlgorithmState::ClusterState* out = outt.add_cluster_states(); + hier.write_state_to_proto(out); + ASSERT_EQ(beta, bayesmix::to_eigen( + out->lin_reg_uni_ls_state().regression_coeffs())); + ASSERT_EQ(sigma2, out->lin_reg_uni_ls_state().var()); +} -// TEST(lin_reg_uni_hierarchy, misc) { -// // Build data -// int n = 5; -// int dim = 2; -// Eigen::Vector2d beta_true; -// beta_true << 10.0, 10.0; -// Eigen::MatrixXd cov = Eigen::MatrixXd::Random(n, dim); // each in U[-1,1] -// double sigma2 = 1.0; -// Eigen::VectorXd data(n); -// auto& rng = bayesmix::Rng::Instance().get(); -// for (int i = 0; i < n; i++) { -// data(i) = stan::math::normal_rng(cov.row(i).dot(beta_true), sigma2, -// rng); -// } -// // Initialize objects -// LinRegUniHierarchy hier; -// bayesmix::LinRegUniPrior prior; -// // Create prior parameters -// Eigen::Vector2d beta0 = 0 * beta_true; -// bayesmix::Vector beta0_proto; -// bayesmix::to_proto(beta0, &beta0_proto); -// auto Lambda0 = Eigen::Matrix2d::Identity(); -// bayesmix::Matrix Lambda0_proto; -// bayesmix::to_proto(Lambda0, &Lambda0_proto); -// double a0 = 2.0; -// double b0 = 1.0; -// // Initialize hierarchy -// *prior.mutable_fixed_values()->mutable_mean() = beta0_proto; -// *prior.mutable_fixed_values()->mutable_var_scaling() = Lambda0_proto; -// prior.mutable_fixed_values()->set_shape(a0); -// prior.mutable_fixed_values()->set_scale(b0); -// hier.get_mutable_prior()->CopyFrom(prior); -// hier.initialize(); -// // Extract hypers for reading test -// bayesmix::AlgorithmState::HierarchyHypers out; -// hier.write_hypers_to_proto(&out); -// ASSERT_EQ(beta0, bayesmix::to_eigen(out.lin_reg_uni_state().mean())); -// ASSERT_EQ(Lambda0, -// bayesmix::to_eigen(out.lin_reg_uni_state().var_scaling())); -// ASSERT_EQ(a0, out.lin_reg_uni_state().shape()); -// ASSERT_EQ(b0, out.lin_reg_uni_state().scale()); -// // Add data -// for (int i = 0; i < n; i++) { -// hier.add_datum(i, data.row(i), false, cov.row(i)); -// } -// // Check summary statistics -// // for (int i = 0; i < dim; i++) { -// // for (int j = 0; j < dim; j++) { -// // ASSERT_DOUBLE_EQ(hier.get_covar_sum_squares()(i, j), -// // (cov.transpose() * cov)(i, j)); -// // } -// // ASSERT_DOUBLE_EQ(hier.get_mixed_prod()(i), (cov.transpose() * -// data)(i)); -// // } -// // Compute and check posterior values -// hier.sample_full_cond(); -// auto state = hier.get_state(); -// for (int i = 0; i < dim; i++) { -// ASSERT_GT(state.regression_coeffs(i), beta0(i)); -// } -// } +TEST(lin_reg_uni_hierarchy, misc) { + // Build data + int n = 5; + int dim = 2; + Eigen::Vector2d beta_true; + beta_true << 10.0, 10.0; + Eigen::MatrixXd cov = Eigen::MatrixXd::Random(n, dim); // each in U[-1,1] + double sigma2 = 1.0; + Eigen::VectorXd data(n); + auto& rng = bayesmix::Rng::Instance().get(); + for (int i = 0; i < n; i++) { + data(i) = stan::math::normal_rng(cov.row(i).dot(beta_true), sigma2, rng); + } + // Initialize objects + LinRegUniHierarchy hier; + bayesmix::LinRegUniPrior prior; + // Create prior parameters + Eigen::Vector2d beta0 = 0 * beta_true; + bayesmix::Vector beta0_proto; + bayesmix::to_proto(beta0, &beta0_proto); + auto Lambda0 = Eigen::Matrix2d::Identity(); + bayesmix::Matrix Lambda0_proto; + bayesmix::to_proto(Lambda0, &Lambda0_proto); + double a0 = 2.0; + double b0 = 1.0; + // Initialize hierarchy + *prior.mutable_fixed_values()->mutable_mean() = beta0_proto; + *prior.mutable_fixed_values()->mutable_var_scaling() = Lambda0_proto; + prior.mutable_fixed_values()->set_shape(a0); + prior.mutable_fixed_values()->set_scale(b0); + hier.get_mutable_prior()->CopyFrom(prior); + hier.initialize(); + // Extract hypers for reading test + bayesmix::AlgorithmState::HierarchyHypers out; + hier.write_hypers_to_proto(&out); + ASSERT_EQ(beta0, bayesmix::to_eigen(out.lin_reg_uni_state().mean())); + ASSERT_EQ(Lambda0, + bayesmix::to_eigen(out.lin_reg_uni_state().var_scaling())); + ASSERT_EQ(a0, out.lin_reg_uni_state().shape()); + ASSERT_EQ(b0, out.lin_reg_uni_state().scale()); + // Add data + for (int i = 0; i < n; i++) { + hier.add_datum(i, data.row(i), true, cov.row(i)); + } + + // Check summary statistics + // for (int i = 0; i < dim; i++) { + // for (int j = 0; j < dim; j++) { + // ASSERT_DOUBLE_EQ(hier.get_covar_sum_squares()(i, j), + // (cov.transpose() * cov)(i, j)); + // } + // ASSERT_DOUBLE_EQ(hier.get_mixed_prod()(i), (cov.transpose() * data)(i)); + // } + + // Compute and check posterior values + hier.sample_full_cond(); + auto state = hier.get_state(); + for (int i = 0; i < dim; i++) { + ASSERT_GT(state.regression_coeffs(i), beta0(i)); + } +} TEST(nnxig_hierarchy, draw) { auto hier = std::make_shared(); From 0152575295e2f5b2454e7d1d187a7f2f283620ec Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:12:01 +0100 Subject: [PATCH 186/317] Add lin_reg_uni_hierarchy --- src/includes.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/includes.h b/src/includes.h index 0f8f51253..ae457596c 100644 --- a/src/includes.h +++ b/src/includes.h @@ -9,7 +9,7 @@ #include "algorithms/neal8_algorithm.h" #include "collectors/file_collector.h" #include "collectors/memory_collector.h" -// #include "hierarchies/lin_reg_uni_hierarchy.h" +#include "hierarchies/lin_reg_uni_hierarchy.h" // #include "hierarchies/lapnig_hierarchy.h" #include "hierarchies/load_hierarchies.h" // #include "hierarchies/fa_hierarchy.h" From 1de8fa09eb27082578b38b5edee5b3d2a5d4b00a Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 4 Mar 2022 22:12:28 +0100 Subject: [PATCH 187/317] Add lin_reg_uni_hierarchy --- src/includes.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/includes.h b/src/includes.h index ae457596c..03ea6dcb0 100644 --- a/src/includes.h +++ b/src/includes.h @@ -9,8 +9,8 @@ #include "algorithms/neal8_algorithm.h" #include "collectors/file_collector.h" #include "collectors/memory_collector.h" +#include "hierarchies/lapnig_hierarchy.h" #include "hierarchies/lin_reg_uni_hierarchy.h" -// #include "hierarchies/lapnig_hierarchy.h" #include "hierarchies/load_hierarchies.h" // #include "hierarchies/fa_hierarchy.h" #include "hierarchies/nnig_hierarchy.h" From 5e75eb2078a47a39b9e7f4b09419132a79e69de0 Mon Sep 17 00:00:00 2001 From: TeoGiane Date: Sat, 5 Mar 2022 15:07:26 +0100 Subject: [PATCH 188/317] Improved tests --- test/hierarchies.cc | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/test/hierarchies.cc b/test/hierarchies.cc index 027457a69..55e516f27 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -58,7 +58,7 @@ TEST(nnig_hierarchy, sample_given_data) { datum << 4.5; auto hier2 = hier->clone(); - hier2->add_datum(0, datum, true); + hier2->add_datum(0, datum, false); hier2->sample_full_cond(); bayesmix::AlgorithmState out; @@ -190,11 +190,10 @@ TEST(lin_reg_uni_hierarchy, state_read_write) { } TEST(lin_reg_uni_hierarchy, misc) { + // Build data - int n = 5; - int dim = 2; - Eigen::Vector2d beta_true; - beta_true << 10.0, 10.0; + int n = 5, dim = 2; + Eigen::Vector2d beta_true({10.0, 10.0}); Eigen::MatrixXd cov = Eigen::MatrixXd::Random(n, dim); // each in U[-1,1] double sigma2 = 1.0; Eigen::VectorXd data(n); @@ -202,25 +201,25 @@ TEST(lin_reg_uni_hierarchy, misc) { for (int i = 0; i < n; i++) { data(i) = stan::math::normal_rng(cov.row(i).dot(beta_true), sigma2, rng); } + // Initialize objects LinRegUniHierarchy hier; bayesmix::LinRegUniPrior prior; + // Create prior parameters Eigen::Vector2d beta0 = 0 * beta_true; - bayesmix::Vector beta0_proto; - bayesmix::to_proto(beta0, &beta0_proto); auto Lambda0 = Eigen::Matrix2d::Identity(); - bayesmix::Matrix Lambda0_proto; - bayesmix::to_proto(Lambda0, &Lambda0_proto); double a0 = 2.0; double b0 = 1.0; + // Initialize hierarchy - *prior.mutable_fixed_values()->mutable_mean() = beta0_proto; - *prior.mutable_fixed_values()->mutable_var_scaling() = Lambda0_proto; + bayesmix::to_proto(beta0, prior.mutable_fixed_values()->mutable_mean()); + bayesmix::to_proto(Lambda0, prior.mutable_fixed_values()->mutable_var_scaling()); prior.mutable_fixed_values()->set_shape(a0); prior.mutable_fixed_values()->set_scale(b0); hier.get_mutable_prior()->CopyFrom(prior); hier.initialize(); + // Extract hypers for reading test bayesmix::AlgorithmState::HierarchyHypers out; hier.write_hypers_to_proto(&out); @@ -229,9 +228,10 @@ TEST(lin_reg_uni_hierarchy, misc) { bayesmix::to_eigen(out.lin_reg_uni_state().var_scaling())); ASSERT_EQ(a0, out.lin_reg_uni_state().shape()); ASSERT_EQ(b0, out.lin_reg_uni_state().scale()); + // Add data for (int i = 0; i < n; i++) { - hier.add_datum(i, data.row(i), true, cov.row(i)); + hier.add_datum(i, data.row(i), false, cov.row(i)); } // Check summary statistics @@ -288,7 +288,7 @@ TEST(nnxig_hierarchy, sample_given_data) { datum << 4.5; auto hier2 = hier->clone(); - hier2->add_datum(0, datum, true); + hier2->add_datum(0, datum, false); hier2->sample_full_cond(); bayesmix::AlgorithmState out; From 4bd188ce47bd25c6bb96a540036a573c0c336640 Mon Sep 17 00:00:00 2001 From: TeoGiane Date: Sat, 5 Mar 2022 15:08:04 +0100 Subject: [PATCH 189/317] Fixed bug via workaround --- src/hierarchies/updaters/conjugate_updater.h | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/hierarchies/updaters/conjugate_updater.h b/src/hierarchies/updaters/conjugate_updater.h index b4f1cefe5..eb92e9ce1 100644 --- a/src/hierarchies/updaters/conjugate_updater.h +++ b/src/hierarchies/updaters/conjugate_updater.h @@ -42,14 +42,15 @@ void ConjugateUpdater::draw(AbstractLikelihood& like, auto& priorcast = downcast_prior(prior); // Sample from the full conditional of a conjugate hierarchy - bool set_card = true; + bool set_card = true, use_post_hypers=true; if (likecast.get_card() == 0) { - likecast.set_state_from_proto(*priorcast.sample(false), !set_card); + likecast.set_state_from_proto(*priorcast.sample(!use_post_hypers), !set_card); } else { - if (update_params) { - compute_posterior_hypers(likecast, priorcast); - } - likecast.set_state_from_proto(*prior.sample(true), !set_card); + auto prev_hypers = priorcast.get_posterior_hypers(); + compute_posterior_hypers(likecast, priorcast); + likecast.set_state_from_proto(*priorcast.sample(use_post_hypers), !set_card); + if (!update_params) + priorcast.set_posterior_hypers(prev_hypers); } } From 9717021ba96a4c84010952bda068470d398e2df4 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 8 Mar 2022 17:12:32 +0100 Subject: [PATCH 190/317] Improve documentation --- src/hierarchies/abstract_hierarchy.h | 37 ++++-- src/hierarchies/base_hierarchy.h | 112 ++---------------- .../likelihoods/abstract_likelihood.h | 3 + src/hierarchies/likelihoods/base_likelihood.h | 7 ++ src/hierarchies/priors/abstract_prior_model.h | 5 + src/hierarchies/priors/base_prior_model.h | 13 ++ 6 files changed, 70 insertions(+), 107 deletions(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index fc9dc78a5..11bff7549 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -51,13 +51,27 @@ class AbstractHierarchy { public: + //! Set the likelihood for the current hierarchy. Implemented in the + //! BaseHierarchy class virtual void set_likelihood(std::shared_ptr like_) = 0; + + //! Set the prior model for the current hierarchy. Implemented in the + //! BaseHierarchy class virtual void set_prior(std::shared_ptr prior_) = 0; + + //! Set the update algorithm for the current hierarchy. Implemented in the + //! BaseHierarchy class virtual void set_updater(std::shared_ptr updater_) = 0; + //! Returns (a pointer to) the likelihood for the current hierarchy. + //! Implemented in the BaseHierarchy class virtual std::shared_ptr get_likelihood() = 0; + + //! Returns (a pointer to) the prior model for the current hierarchy. + //! Implemented in the BaseHierarchy class virtual std::shared_ptr get_prior() = 0; + //! Default destructor virtual ~AbstractHierarchy() = default; //! Returns an independent, data-less copy of this object @@ -86,7 +100,8 @@ class AbstractHierarchy { const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { if (is_conjugate()) { - throw std::runtime_error("prior_pred_lpdf() not implemented yet"); + throw std::runtime_error( + "prior_pred_lpdf() not implemented for this hierarchy"); } else { throw std::runtime_error( "Cannot call prior_pred_lpdf() from a non-conjugate hierarchy"); @@ -101,7 +116,8 @@ class AbstractHierarchy { const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { if (is_conjugate()) { - throw std::runtime_error("conditional_pred_lpdf() not implemented yet"); + throw std::runtime_error( + "conditional_pred_lpdf() not implemented for this hierarchy"); } else { throw std::runtime_error( "Cannot call conditional_pred_lpdf() from a non-conjugate " @@ -126,7 +142,8 @@ class AbstractHierarchy { const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const { if (is_conjugate()) { - throw std::runtime_error("prior_pred_lpdf_grid() not yet implemented"); + throw std::runtime_error( + "prior_pred_lpdf_grid() not implemented for this hierarchy"); } else { throw std::runtime_error( "Cannot call prior_pred_lpdf_grid() from a non-conjugate hierarchy"); @@ -142,7 +159,7 @@ class AbstractHierarchy { const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const { if (is_conjugate()) { throw std::runtime_error( - "conditional_pred_lpdf_grid() not yet implemented"); + "conditional_pred_lpdf_grid() not implemented for this hierarchy"); } else { throw std::runtime_error( "Cannot call conditional_pred_lpdf_grid() from a non-conjugate " @@ -256,7 +273,8 @@ class AbstractHierarchy { throw std::runtime_error( "Cannot call like_lpdf() from a non-dependent hierarchy"); } else { - throw std::runtime_error("like_lpdf() not implemented"); + throw std::runtime_error( + "like_lpdf() not implemented for this hierarchy"); } } @@ -268,7 +286,8 @@ class AbstractHierarchy { throw std::runtime_error( "Cannot call like_lpdf() from a dependent hierarchy"); } else { - throw std::runtime_error("like_lpdf() not implemented"); + throw std::runtime_error( + "like_lpdf() not implemented for this hierarchy"); } } @@ -284,7 +303,8 @@ class AbstractHierarchy { "Cannot call update_summary_statistics() from a non-dependent " "hierarchy"); } else { - throw std::runtime_error("update_summary_statistics() not implemented"); + throw std::runtime_error( + "update_summary_statistics() not implemented for this hierarchy"); } } @@ -298,7 +318,8 @@ class AbstractHierarchy { "Cannot call update_summary_statistics() from a dependent " "hierarchy"); } else { - throw std::runtime_error("update_summary_statistics() not implemented"); + throw std::runtime_error( + "update_summary_statistics() not implemented for this hierarchy"); } } }; diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 24c1025b1..217a29170 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -38,7 +38,7 @@ class BaseHierarchy : public AbstractHierarchy { //! Container for the prior model of the hierarchy std::shared_ptr prior = std::make_shared(); - //! Container for the update algorithm adopted + //! Container for the update algorithm std::shared_ptr updater; public: @@ -87,12 +87,13 @@ class BaseHierarchy : public AbstractHierarchy { return out; }; + //! Returns an independent, data-less deep copy of this object std::shared_ptr deep_clone() const override { // Create copy of the hierarchy auto out = std::make_shared(static_cast(*this)); - // Simple Clone is enough for Likelihood + // Simple clone for Likelihood is enough out->set_likelihood(std::static_pointer_cast(like->clone())); - // Deep-Clone required for PriorModel + // Deep clone required for PriorModel out->set_prior(std::static_pointer_cast(prior->deep_clone())); return out; } @@ -332,10 +333,13 @@ class BaseHierarchy : public AbstractHierarchy { like->clear_summary_statistics(); }; + //! Returns whether the hierarchy models multivariate data or not bool is_multivariate() const override { return like->is_multivariate(); }; + //! Returns whether the hierarchy depends on covariate values or not bool is_dependent() const override { return like->is_dependent(); }; + //! Returns whether the hierarchy represents a conjugate model or not bool is_conjugate() const override { return updater->is_conjugate(); }; //! Sets the (pointer to the) dataset matrix @@ -347,18 +351,19 @@ class BaseHierarchy : public AbstractHierarchy { //! Initializes state parameters to appropriate values virtual void initialize_state() = 0; - // ADD EXEPTION HANDLING + // ADD EXEPTION HANDLING FOR is_dependent()? virtual double marg_lpdf(const HyperParams ¶ms, const Eigen::RowVectorXd &datum) const { if (!is_conjugate()) { throw std::runtime_error( "Call marg_lpdf() for a non-conjugate hierarchy"); } else { - throw std::runtime_error("marg_lpdf() not yet implemented"); + throw std::runtime_error( + "marg_lpdf() not implemented for this hierarchy"); } } - // ADD EXEPTION HANDLING + // ADD EXEPTION HANDLING FOR is_dependent()? virtual double marg_lpdf(const HyperParams ¶ms, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const { @@ -366,7 +371,8 @@ class BaseHierarchy : public AbstractHierarchy { throw std::runtime_error( "Call marg_lpdf() for a non-conjugate hierarchy"); } else { - throw std::runtime_error("marg_lpdf() not yet implemented"); + throw std::runtime_error( + "marg_lpdf() not implemented for this hierarchy"); } } @@ -375,97 +381,5 @@ class BaseHierarchy : public AbstractHierarchy { }; // TODO: Move definitions outside the class to improve code cleaness -// TODO: Move this docs in the right place - -//! Returns the struct of the current prior hyperparameters -// Hyperparams get_hypers() const { return *hypers; } - -//! Returns the struct of the current posterior hyperparameters -// Hyperparams get_posterior_hypers() const { return posterior_hypers; } - -//! Raises an error if the prior pointer is not initialized -// void check_prior_is_set() const { -// if (prior == nullptr) { -// throw std::invalid_argument("Hierarchy prior was not provided"); -// } -// } - -//! Re-initializes the prior of the hierarchy to a newly created object -// void create_empty_prior() { prior.reset(new Prior); } - -//! Sets the cardinality of the cluster -// void set_card(const int card_) { -// card = card_; -// log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); -// } - -//! Writes current state to a Protobuf message and return a shared_ptr -//! New hierarchies have to first modify the field 'oneof val' in the -//! AlgoritmState::ClusterState message by adding the appropriate type -// virtual std::shared_ptr -// get_state_proto() const = 0; - -//! Writes current value of hyperparameters to a Protobuf message and -//! return a shared_ptr. -//! New hierarchies have to first modify the field 'oneof val' in the -//! AlgoritmState::HierarchyHypers message by adding the appropriate type -// virtual std::shared_ptr -// get_hypers_proto() const = 0; - -//! Initializes hierarchy hyperparameters to appropriate values -// virtual void initialize_hypers() = 0; - -//! Resets cardinality and indexes of data in this cluster -// void clear_data() { -// set_card(0); -// cluster_data_idx = std::set(); -// } - -//! Down-casts the given generic proto message to a ClusterState proto -// bayesmix::AlgorithmState::ClusterState *downcast_state( -// google::protobuf::Message *state_) const { -// return google::protobuf::internal::down_cast< -// bayesmix::AlgorithmState::ClusterState *>(state_); -// } - -//! Down-casts the given generic proto message to a ClusterState proto -// const bayesmix::AlgorithmState::ClusterState &downcast_state( -// const google::protobuf::Message &state_) const { -// return google::protobuf::internal::down_cast< -// const bayesmix::AlgorithmState::ClusterState &>(state_); -// } - -//! Down-casts the given generic proto message to a HierarchyHypers proto -// bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( -// google::protobuf::Message *state_) const { -// return google::protobuf::internal::down_cast< -// bayesmix::AlgorithmState::HierarchyHypers *>(state_); -// } - -//! Down-casts the given generic proto message to a HierarchyHypers proto -// const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( -// const google::protobuf::Message &state_) const { -// return google::protobuf::internal::down_cast< -// const bayesmix::AlgorithmState::HierarchyHypers &>(state_); -// } - -// //! Container for prior hyperparameters values -// std::shared_ptr hypers; - -// //! Container for posterior hyperparameters values -// Hyperparams posterior_hypers; - -// //! Pointer to a Protobuf prior object for this class -// std::shared_ptr prior; - -// //! Set of indexes of data points belonging to this cluster -// std::set cluster_data_idx; - -// //! Current cardinality of this cluster -// int card = 0; - -// //! Logarithm of current cardinality of this cluster -// double log_card = stan::math::NEGATIVE_INFTY; -// }; #endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 936d93b18..a6593bec8 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -90,6 +90,9 @@ class AbstractLikelihood { virtual Eigen::VectorXd get_unconstrained_state() = 0; protected: + //! Writes current state to a Protobuf message and return a shared_ptr + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::ClusterState message by adding the appropriate type virtual std::shared_ptr get_state_proto() const = 0; diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index c2072779a..11c41f2b8 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -109,23 +109,27 @@ class BaseLikelihood : public AbstractLikelihood { const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + //! Resets cardinality and indexes of data in this cluster void clear_data() { set_card(0); cluster_data_idx = std::set(); } protected: + //! Sets the cardinality of the cluster void set_card(const int card_) { card = card_; log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); } + //! Down-casts the given generic proto message to a ClusterState proto bayesmix::AlgorithmState::ClusterState *downcast_state( google::protobuf::Message *state_) const { return google::protobuf::internal::down_cast< bayesmix::AlgorithmState::ClusterState *>(state_); } + //! Down-casts the given generic proto message to a ClusterState proto const bayesmix::AlgorithmState::ClusterState &downcast_state( const google::protobuf::Message &state_) const { return google::protobuf::internal::down_cast< @@ -134,10 +138,13 @@ class BaseLikelihood : public AbstractLikelihood { State state; + //! Current cardinality of this cluster int card = 0; + //! Logarithm of current cardinality of this cluster double log_card = stan::math::NEGATIVE_INFTY; + //! Set of indexes of data points belonging to this cluster std::set cluster_data_idx; }; diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 74f7c03af..0faaecd65 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -56,9 +56,14 @@ class AbstractPriorModel { virtual void write_hypers_to_proto(google::protobuf::Message *out) const = 0; protected: + //! Writes current value of hyperparameters to a Protobuf message and + //! return a shared_ptr. + //! New hierarchies have to first modify the field 'oneof val' in the + //! AlgoritmState::HierarchyHypers message by adding the appropriate type virtual std::shared_ptr get_hypers_proto() const = 0; + //! Initializes hierarchy hyperparameters to appropriate values virtual void initialize_hypers() = 0; }; diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 518d687e9..e6c8fedc2 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -74,8 +74,10 @@ class BasePriorModel : public AbstractPriorModel { virtual google::protobuf::Message *get_mutable_prior() override; + //! Returns the struct of the current prior hyperparameters HyperParams get_hypers() const { return *hypers; } + //! Returns the struct of the current posterior hyperparameters HyperParams get_posterior_hypers() const { return post_hypers; } void set_posterior_hypers(const HyperParams &_post_hypers) { @@ -87,18 +89,24 @@ class BasePriorModel : public AbstractPriorModel { void initialize(); protected: + //! Raises an error if the prior pointer is not initialized void check_prior_is_set() const; + //! Re-initializes the prior of the hierarchy to a newly created object void create_empty_prior() { prior.reset(new Prior); } + //! Re-initializes the hyperparameters of the hierarchy to a newly created + //! object void create_empty_hypers() { hypers.reset(new HyperParams); } + //! Down-casts the given generic proto message to a HierarchyHypers proto bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( google::protobuf::Message *state_) const { return google::protobuf::internal::down_cast< bayesmix::AlgorithmState::HierarchyHypers *>(state_); } + //! Down-casts the given generic proto message to a HierarchyHypers proto const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( const google::protobuf::Message &state_) const { return google::protobuf::internal::down_cast< @@ -111,8 +119,13 @@ class BasePriorModel : public AbstractPriorModel { const bayesmix::AlgorithmState::ClusterState &>(state_); } + //! Container for prior hyperparameters values std::shared_ptr hypers = std::make_shared(); + + //! Container for posterior hyperparameters values HyperParams post_hypers; + + //! Pointer to a Protobuf prior object for this class std::shared_ptr prior; }; From 5c3dd62eb86d17bb936c99bbbfac732080e9d651 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 8 Mar 2022 18:54:56 +0100 Subject: [PATCH 191/317] Add documentation --- src/hierarchies/base_hierarchy.h | 43 ++++++++++++++++++- .../likelihoods/abstract_likelihood.h | 42 +++++++++++++++--- 2 files changed, 77 insertions(+), 8 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 217a29170..31fc01508 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -6,7 +6,7 @@ #include #include #include -#include +// #include #include #include "abstract_hierarchy.h" @@ -44,6 +44,8 @@ class BaseHierarchy : public AbstractHierarchy { public: using HyperParams = decltype(prior->get_hypers()); + //! Constructor that allows the specification of Likelihood, PriorModel and + //! Updater for a given Hierarchy BaseHierarchy(std::shared_ptr like_ = nullptr, std::shared_ptr prior_ = nullptr, std::shared_ptr updater_ = nullptr) { @@ -60,21 +62,30 @@ class BaseHierarchy : public AbstractHierarchy { } } + //! Default destructor ~BaseHierarchy() = default; + //! Set the likelihood for the current hierarchy void set_likelihood(std::shared_ptr like_) override { like = std::static_pointer_cast(like_); } + + //! Set the prior model for the current hierarchy void set_prior(std::shared_ptr prior_) override { prior = std::static_pointer_cast(prior_); } + + //! Set the update algorithm for the current hierarchy void set_updater(std::shared_ptr updater_) override { updater = updater_; }; + //! Returns (a pointer to) the likelihood for the current hierarchy std::shared_ptr get_likelihood() override { return like; } + + //! Returns (a pointer to) the prior model for the current hierarchy. std::shared_ptr get_prior() override { return prior; } //! Returns an independent, data-less copy of this object @@ -118,6 +129,7 @@ class BaseHierarchy : public AbstractHierarchy { // return out; // } + //! Public wrapper for `like_lpdf()` methods double get_like_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const override { @@ -135,6 +147,7 @@ class BaseHierarchy : public AbstractHierarchy { }; // ADD EXCEPTION HANDLING + //! Public wrapper for `marg_lpdf()` methods double get_marg_lpdf( const HyperParams ¶ms, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const { @@ -146,6 +159,10 @@ class BaseHierarchy : public AbstractHierarchy { } // ADD EXCEPTION HANDLING + //! Evaluates the log-prior predictive distribution of data in a single point + //! @param datum Point which is to be evaluated + //! @param covariate (Optional) covariate vector associated to datum + //! @return The evaluation of the lpdf double prior_pred_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const override { @@ -153,6 +170,10 @@ class BaseHierarchy : public AbstractHierarchy { } // ADD EXCEPTION HANDLING + //! Evaluates the log-prior predictive distr. of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf Eigen::VectorXd prior_pred_lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) @@ -181,6 +202,10 @@ class BaseHierarchy : public AbstractHierarchy { } // ADD EXCEPTION HANDLING + //! Evaluates the log-conditional predictive distr. of data in a single point + //! @param datum Point which is to be evaluated + //! @param covariate (Optional) covariate vector associated to datum + //! @return The evaluation of the lpdf double conditional_pred_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const override { @@ -188,6 +213,10 @@ class BaseHierarchy : public AbstractHierarchy { } // ADD EXCEPTION HANDLING + //! Evaluates the log-prior predictive distr. of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf Eigen::VectorXd conditional_pred_lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) @@ -254,6 +283,7 @@ class BaseHierarchy : public AbstractHierarchy { static_cast(this)->sample_full_cond(true); }; + //! Updates hyperparameter values given a vector of cluster states void update_hypers(const std::vector &states) override { prior->update_hypers(states); @@ -297,10 +327,12 @@ class BaseHierarchy : public AbstractHierarchy { prior->write_hypers_to_proto(out); }; + //! Read and set state values from a given Protobuf message void set_state_from_proto(const google::protobuf::Message &state_) override { like->set_state_from_proto(state_); }; + //! Read and set hyperparameter values from a given Protobuf message void set_hypers_from_proto( const google::protobuf::Message &state_) override { prior->set_hypers_from_proto(state_); @@ -352,6 +384,10 @@ class BaseHierarchy : public AbstractHierarchy { virtual void initialize_state() = 0; // ADD EXEPTION HANDLING FOR is_dependent()? + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf virtual double marg_lpdf(const HyperParams ¶ms, const Eigen::RowVectorXd &datum) const { if (!is_conjugate()) { @@ -364,6 +400,11 @@ class BaseHierarchy : public AbstractHierarchy { } // ADD EXEPTION HANDLING FOR is_dependent()? + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @param covariate Covariate vector associated to datum + //! @return The evaluation of the lpdf virtual double marg_lpdf(const HyperParams ¶ms, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const { diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index a6593bec8..ba4a6b8bf 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -11,11 +11,14 @@ class AbstractLikelihood { public: + //! Default destructor virtual ~AbstractLikelihood() = default; - // IMPLEMENTED in BaseLikelihood + //! Returns an independent, data-less copy of this object. Implemented in + //! BaseLikelihood virtual std::shared_ptr clone() const = 0; + //! Public wrapper for `compute_lpdf()` methods double lpdf( const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { @@ -45,36 +48,45 @@ class AbstractLikelihood { "cluster_lpdf_from_unconstrained() not yet implemented"); } + //! Evaluates the log-likelihood of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf virtual Eigen::VectorXd lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const = 0; - // AGGIUNGERE CLUST_LPDF (CHE VALUTA LA LIKELIHOOD CONGIUNTA SU TUTTO IL - // CLUSTER) - + //! Returns whether the likelihood models multivariate data or not virtual bool is_multivariate() const = 0; + //! Returns whether the likelihood depends on covariate values or not virtual bool is_dependent() const = 0; + //! Read and set state values from a given Protobuf message virtual void set_state_from_proto(const google::protobuf::Message &state_, bool update_card = true) = 0; + //! Read and set state values from the vector of unconstrained parameters virtual void set_state_from_unconstrained( const Eigen::VectorXd &unconstrained_state) = 0; - // IMPLEMENTED in BaseLikelihood + //! Writes current state to a Protobuf message by pointer. Implemented in + //! BaseLikelihood virtual void write_state_to_proto(google::protobuf::Message *out) const = 0; - // IMPLEMENTED in BaseLikelihood + //! Adds a datum and its index to the likelihood. Implemented in + //! BaseLikelihood virtual void add_datum( const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) = 0; - // IMPLEMENTED in BaseLikelihood + //! Removes a datum and its index from the likelihood. Implemented in + //! BaseLikelihood virtual void remove_datum( const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) = 0; + //! Public wrapper for `update_sum_stats()` methods void update_summary_statistics(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate, bool add) { @@ -85,8 +97,10 @@ class AbstractLikelihood { } } + //! Resets the values of the summary statistics in the likelihood virtual void clear_summary_statistics() = 0; + //! Returns the vector of the unconstrained parameters for this likelihood virtual Eigen::VectorXd get_unconstrained_state() = 0; protected: @@ -96,6 +110,9 @@ class AbstractLikelihood { virtual std::shared_ptr get_state_proto() const = 0; + //! Evaluates the log-likelihood of data in a single point + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf virtual double compute_lpdf(const Eigen::RowVectorXd &datum) const { if (is_dependent()) { throw std::runtime_error( @@ -106,6 +123,10 @@ class AbstractLikelihood { } } + //! Evaluates the log-likelihood of data in a single point + //! @param datum Point which is to be evaluated + //! @param covariate Covariate vector associated to datum + //! @return The evaluation of the lpdf virtual double compute_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const { if (!is_dependent()) { @@ -117,6 +138,9 @@ class AbstractLikelihood { } } + //! Updates cluster statistics when a datum is added or removed from it + //! @param datum Data point which is being added or removed + //! @param add Whether the datum is being added or removed virtual void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) { if (is_dependent()) { throw std::runtime_error( @@ -126,6 +150,10 @@ class AbstractLikelihood { } } + //! Updates cluster statistics when a datum is added or removed from it + //! @param datum Data point which is being added or removed + //! @param covariate Covariate vector associated to datum + //! @param add Whether the datum is being added or removed virtual void update_sum_stats(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate, bool add) { From c13bc6578e88c364f24e592eee1a87a0d249ccfd Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 9 Mar 2022 10:01:28 +0100 Subject: [PATCH 192/317] Improved documentation --- .../likelihoods/abstract_likelihood.h | 19 ++++- src/hierarchies/likelihoods/base_likelihood.h | 73 ++++++++++++++----- src/hierarchies/priors/abstract_prior_model.h | 34 ++++++++- src/hierarchies/priors/base_prior_model.h | 59 +++++++++++---- src/hierarchies/updaters/abstract_updater.h | 14 +++- 5 files changed, 158 insertions(+), 41 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index ba4a6b8bf..705efbb06 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -35,17 +35,32 @@ class AbstractLikelihood { //! the parameter vector can range over (-inf, inf). //! Usually, some kind of transformation is required from the unconstrained //! parameterization to the actual parameterization. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood over all data in the cluster virtual double cluster_lpdf_from_unconstrained( Eigen::VectorXd unconstrained_params) const { throw std::runtime_error( - "cluster_lpdf_from_unconstrained() not yet implemented"); + "cluster_lpdf_from_unconstrained() not implemented for this " + "likelihood"); } + //! Evaluates the log likelihood over all the data in the cluster + //! given unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. This version using + //! `stan::math::var` type is required for Stan automatic aifferentiation. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood over all data in the cluster virtual stan::math::var cluster_lpdf_from_unconstrained( Eigen::Matrix unconstrained_params) const { throw std::runtime_error( - "cluster_lpdf_from_unconstrained() not yet implemented"); + "cluster_lpdf_from_unconstrained() not implemented for this " + "likelihood"); } //! Evaluates the log-likelihood of data in a grid of points diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 11c41f2b8..12465a63e 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -36,9 +36,13 @@ auto cluster_lpdf_from_unconstrained( template class BaseLikelihood : public AbstractLikelihood { public: + //! Default constructor BaseLikelihood() = default; + + //! Default destructor ~BaseLikelihood() = default; + //! Returns an independent, data-less copy of this object virtual std::shared_ptr clone() const override { auto out = std::make_shared(static_cast(*this)); out->clear_data(); @@ -46,29 +50,31 @@ class BaseLikelihood : public AbstractLikelihood { return out; } - // The unconstrained parameters are mean and log(var) - - // double cluster_lpdf_from_unconstrained( - // Eigen::VectorXd unconstrained_params) const override { - // return static_cast(*this) - // .template cluster_lpdf_from_unconstrained( - // unconstrained_params); - // } - - // stan::math::var cluster_lpdf_from_unconstrained( - // Eigen::Matrix - // unconstrained_params) const override { - // return static_cast(*this) - // .template cluster_lpdf_from_unconstrained( - // unconstrained_params); - // } - + //! Evaluates the log likelihood over all the data in the cluster + //! given unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood over all data in the cluster double cluster_lpdf_from_unconstrained( Eigen::VectorXd unconstrained_params) const override { return internal::cluster_lpdf_from_unconstrained( static_cast(*this), unconstrained_params, 0); } + //! Evaluates the log likelihood over all the data in the cluster + //! given unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. This version using + //! `stan::math::var` type is required for Stan automatic aifferentiation. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood over all data in the cluster stan::math::var cluster_lpdf_from_unconstrained( Eigen::Matrix unconstrained_params) const override { @@ -76,35 +82,49 @@ class BaseLikelihood : public AbstractLikelihood { static_cast(*this), unconstrained_params, 0); } + //! Evaluates the log-likelihood of data in a grid of points + //! @param data Grid of points (by row) which are to be evaluated + //! @param covariates (Optional) covariate vectors associated to data + //! @return The evaluation of the lpdf virtual Eigen::VectorXd lpdf_grid(const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) const override; + //! Returns the current cardinality of the cluster int get_card() const { return card; } + //! Returns the logarithm of the current cardinality of the cluster double get_log_card() const { return log_card; } + //! Returns the indexes of data points belonging to this cluster std::set get_data_idx() const { return cluster_data_idx; } + //! Writes current state to a Protobuf message by pointer void write_state_to_proto(google::protobuf::Message *out) const override; + //! Returns the class of the current state for the likelihood State get_state() const { return state; } + //! Returns a vector storing the state in its unconstrained form Eigen::VectorXd get_unconstrained_state() override { return state.get_unconstrained(); } + //! Updates the state of the likelihood with the object given as input void set_state(const State &_state) { state = _state; }; + //! Updates the state of the likelihood starting from its unconstrained form void set_state_from_unconstrained( const Eigen::VectorXd &unconstrained_state) override { state.set_from_unconstrained(unconstrained_state); } + //! Adds a datum and its index to the likelihood void add_datum( const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; + //! Removes a datum and its index from the likelihood void remove_datum( const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; @@ -136,6 +156,7 @@ class BaseLikelihood : public AbstractLikelihood { const bayesmix::AlgorithmState::ClusterState &>(state_); } + //! Current state of this cluster State state; //! Current cardinality of this cluster @@ -208,4 +229,22 @@ Eigen::VectorXd BaseLikelihood::lpdf_grid( return lpdf; } +// OLD STUFF +// The unconstrained parameters are mean and log(var) + +// double cluster_lpdf_from_unconstrained( +// Eigen::VectorXd unconstrained_params) const override { +// return static_cast(*this) +// .template cluster_lpdf_from_unconstrained( +// unconstrained_params); +// } + +// stan::math::var cluster_lpdf_from_unconstrained( +// Eigen::Matrix +// unconstrained_params) const override { +// return static_cast(*this) +// .template cluster_lpdf_from_unconstrained( +// unconstrained_params); +// } + #endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_BASE_LIKELIHOOD_H_ diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 0faaecd65..8531264fa 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -13,14 +13,21 @@ class AbstractPriorModel { public: + //! Default destructor virtual ~AbstractPriorModel() = default; - // IMPLEMENTED in BasePriorModel + //! Returns an independent, data-less copy of this object. Implemented in + //! BasePriorModel virtual std::shared_ptr clone() const = 0; - // IMPLEMENTED in BasePriorModel + //! Returns an independent, data-less deep copy of this object. Implemented + //! in BasePriorModel virtual std::shared_ptr deep_clone() const = 0; + //! Evaluates the log likelihood for the prior model, given the state of the + //! cluster + //! @param state_ A Protobuf message storing the current state of the cluster + //! @return The evaluation of the log likelihood virtual double lpdf(const google::protobuf::Message &state_) = 0; //! Evaluates the log likelihood for unconstrained parameter values. @@ -28,11 +35,23 @@ class AbstractPriorModel { //! the parameter vector can range over (-inf, inf). //! Usually, some kind of transformation is required from the unconstrained //! parameterization to the actual parameterization. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood of the prior model virtual double lpdf_from_unconstrained( Eigen::VectorXd unconstrained_params) const { throw std::runtime_error("lpdf_from_unconstrained() not yet implemented"); } + //! Evaluates the log likelihood for unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. This version using + //! `stan::math::var` type is required for Stan automatic aifferentiation. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood of the prior model virtual stan::math::var lpdf_from_unconstrained( Eigen::Matrix unconstrained_params) const { @@ -40,19 +59,26 @@ class AbstractPriorModel { "cluster_lpdf_from_unconstrained() not yet implemented"); } - // Da pensare, come restituisco lo stato? magari un pointer? Oppure delego + //! Sampling from the prior model + //! @param use_post_hypers It is a `bool` which decides whether to use prior + //! or posterior parameters + //! @return A Protobuf message storing the state sampled from the prior model virtual std::shared_ptr sample( bool use_post_hypers) = 0; + //! Updates hyperparameter values given a vector of cluster states virtual void update_hypers( const std::vector &states) = 0; + //! Returns a pointer to the Protobuf message of the prior of this cluster virtual google::protobuf::Message *get_mutable_prior() = 0; + //! Read and set hyperparameter values from a given Protobuf message virtual void set_hypers_from_proto( const google::protobuf::Message &state_) = 0; - // IMPLEMENTED in BasePriorModel + //! Writes current values of the hyperparameters to a Protobuf message by + //! pointer. Implemented in BasePriorModel virtual void write_hypers_to_proto(google::protobuf::Message *out) const = 0; protected: diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index e6c8fedc2..8c43d39fb 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -37,16 +37,35 @@ auto lpdf_from_unconstrained( template class BasePriorModel : public AbstractPriorModel { public: + //! Default constructor BasePriorModel() = default; + //! Default destructor ~BasePriorModel() = default; + //! Evaluates the log likelihood for unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood of the prior model double lpdf_from_unconstrained( Eigen::VectorXd unconstrained_params) const override { return internal::lpdf_from_unconstrained( static_cast(*this), unconstrained_params, 0); } + //! Evaluates the log likelihood for unconstrained parameter values. + //! By unconstrained parameters we mean that each entry of + //! the parameter vector can range over (-inf, inf). + //! Usually, some kind of transformation is required from the unconstrained + //! parameterization to the actual parameterization. This version using + //! `stan::math::var` type is required for Stan automatic aifferentiation. + //! @param unconstrained_params vector collecting the unconstrained + //! parameters + //! @return The evaluation of the log likelihood of the prior model stan::math::var lpdf_from_unconstrained( Eigen::Matrix unconstrained_params) const override { @@ -54,24 +73,13 @@ class BasePriorModel : public AbstractPriorModel { static_cast(*this), unconstrained_params, 0); } - // double lpdf_from_unconstrained( - // Eigen::VectorXd unconstrained_params) const override { - // return static_cast(*this) - // .template lpdf_from_unconstrained(unconstrained_params); - // } - - // stan::math::var lpdf_from_unconstrained( - // Eigen::Matrix - // unconstrained_params) const override { - // return static_cast(*this) - // .template lpdf_from_unconstrained( - // unconstrained_params); - // } - + //! Returns an independent, data-less copy of this object virtual std::shared_ptr clone() const override; + //! Returns an independent, data-less deep copy of this object virtual std::shared_ptr deep_clone() const override; + //! Returns a pointer to the Protobuf message of the prior of this cluster virtual google::protobuf::Message *get_mutable_prior() override; //! Returns the struct of the current prior hyperparameters @@ -80,12 +88,16 @@ class BasePriorModel : public AbstractPriorModel { //! Returns the struct of the current posterior hyperparameters HyperParams get_posterior_hypers() const { return post_hypers; } + //! Updates the current value of the posterior hyperparameters void set_posterior_hypers(const HyperParams &_post_hypers) { post_hypers = _post_hypers; }; + //! Writes current values of the hyperparameters to a Protobuf message by + //! pointer void write_hypers_to_proto(google::protobuf::Message *out) const override; + //! Initializes the prior model (both prior and hyperparameters) void initialize(); protected: @@ -113,6 +125,7 @@ class BasePriorModel : public AbstractPriorModel { const bayesmix::AlgorithmState::HierarchyHypers &>(state_); } + //! Down-casts the given generic proto message to a ClusterState proto const bayesmix::AlgorithmState::ClusterState &downcast_state( const google::protobuf::Message &state_) const { return google::protobuf::internal::down_cast< @@ -129,8 +142,7 @@ class BasePriorModel : public AbstractPriorModel { std::shared_ptr prior; }; -// Methods Definitions - +/* *** Methods Definitions *** */ template std::shared_ptr BasePriorModel::clone() const { @@ -192,4 +204,19 @@ void BasePriorModel::check_prior_is_set() const { } } +// OLD STUFF +// double lpdf_from_unconstrained( +// Eigen::VectorXd unconstrained_params) const override { +// return static_cast(*this) +// .template lpdf_from_unconstrained(unconstrained_params); +// } + +// stan::math::var lpdf_from_unconstrained( +// Eigen::Matrix +// unconstrained_params) const override { +// return static_cast(*this) +// .template lpdf_from_unconstrained( +// unconstrained_params); +// } + #endif // BAYESMIX_HIERARCHIES_PRIORS_BASE_PRIOR_MODEL_H_ diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 188c4e46c..4c412397e 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -3,20 +3,30 @@ #include "src/hierarchies/likelihoods/abstract_likelihood.h" #include "src/hierarchies/priors/abstract_prior_model.h" -#include "src/hierarchies/updaters/target_lpdf_unconstrained.h" +// #include "src/hierarchies/updaters/target_lpdf_unconstrained.h" class AbstractUpdater { public: + //! Default destructor virtual ~AbstractUpdater() = default; + //! Returns whether the current updater is for conjugate model or not virtual bool is_conjugate() const { return false; }; + //! Sampling from the full conditional, given the likelihood and the prior + //! model that constitutes the hierarchy + //! @param like The likelihood of the hierarchy + //! @param prior The prior model of the hierarchy + //! @param update_params Save posterior hyperparameters after draw? virtual void draw(AbstractLikelihood &like, AbstractPriorModel &prior, bool update_params) = 0; + //! Computes the posterior hyperparameters required for the sampling in case + //! of conjugate hierarchies virtual void compute_posterior_hypers(AbstractLikelihood &like, AbstractPriorModel &prior) { - throw std::runtime_error("compute_posterior_hypers not implemented"); + throw std::runtime_error( + "compute_posterior_hypers() not implemented for this updater"); } }; From 2c07875630b728bdfa641fbee8a776b404466aa3 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Mar 2022 17:04:37 +0100 Subject: [PATCH 193/317] Improved SFINAE for exeption handling --- src/hierarchies/likelihoods/base_likelihood.h | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 12465a63e..1e48ee64f 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -13,6 +13,7 @@ namespace internal { +/* SFINAE for cluster_lpdf_from_unconstrained() */ template auto cluster_lpdf_from_unconstrained( const Like &like, Eigen::Matrix unconstrained_params, @@ -22,7 +23,6 @@ auto cluster_lpdf_from_unconstrained( return like.template cluster_lpdf_from_unconstrained( unconstrained_params); } - template auto cluster_lpdf_from_unconstrained( const Like &like, Eigen::Matrix unconstrained_params, @@ -31,6 +31,33 @@ auto cluster_lpdf_from_unconstrained( "cluster_lpdf_from_unconstrained() not yet implemented")); } +/* SFINAE for get_unconstrained_state() */ +template +auto get_unconstrained_state(const State &state, int) + -> decltype(state.get_unconstrained()) { + return state.get_unconstrained(); +} +template +auto get_unconstrained_state(const State &state, double) -> Eigen::VectorXd { + throw(std::runtime_error("get_unconstrained_state() not yet implemented")); +} + +/* SFINAE for set_state_from_unconstrained() */ +template +auto set_state_from_unconstrained(State &state, + const Eigen::VectorXd &unconstrained_state, + int) + -> decltype(state.set_from_unconstrained(unconstrained_state)) { + state.set_from_unconstrained(unconstrained_state); +} +template +auto set_state_from_unconstrained(State &state, + const Eigen::VectorXd &unconstrained_state, + double) -> void { + throw(std::runtime_error( + "set_state_from_unconstrained() not yet implemented")); +} + } // namespace internal template @@ -107,7 +134,8 @@ class BaseLikelihood : public AbstractLikelihood { //! Returns a vector storing the state in its unconstrained form Eigen::VectorXd get_unconstrained_state() override { - return state.get_unconstrained(); + return internal::get_unconstrained_state(state, 0); + // return state.get_unconstrained(); } //! Updates the state of the likelihood with the object given as input @@ -116,7 +144,8 @@ class BaseLikelihood : public AbstractLikelihood { //! Updates the state of the likelihood starting from its unconstrained form void set_state_from_unconstrained( const Eigen::VectorXd &unconstrained_state) override { - state.set_from_unconstrained(unconstrained_state); + internal::set_state_from_unconstrained(state, unconstrained_state, 0); + // state.set_from_unconstrained(unconstrained_state); } //! Adds a datum and its index to the likelihood From 4f1d86cfd01ee477a16e86b95db670fcadf6103c Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Mar 2022 17:05:48 +0100 Subject: [PATCH 194/317] Uncomment FAState --- src/proto/algorithm_state.proto | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/proto/algorithm_state.proto b/src/proto/algorithm_state.proto index 0f387eda9..8975320dc 100644 --- a/src/proto/algorithm_state.proto +++ b/src/proto/algorithm_state.proto @@ -25,7 +25,7 @@ message AlgorithmState { MultiLSState multi_ls_state = 2; // State of a multivariate location-scale family LinRegUniLSState lin_reg_uni_ls_state = 4; // State of a linear regression univariate location-scale family Vector general_state = 5; // Just a vector of doubles - // FAState fa_state = 6; // State of a Mixture of Factor Analysers + FAState fa_state = 6; // State of a Mixture of Factor Analysers } int32 cardinality = 3; // How many observations are in this cluster } @@ -44,7 +44,7 @@ message AlgorithmState { MultiNormalIGDistribution lin_reg_uni_state = 4; NxIGDistribution nnxig_state = 5; // LapNIGState lapnig_state = 6; - // FAPriorDistribution fa_state = 7; + FAPriorDistribution fa_state = 7; } } HierarchyHypers hierarchy_hypers = 5; // The current values of the hyperparameters of the hierarchy From 53a14154e6491716b7acfcd1330b58ad9526b485 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 18 Mar 2022 17:06:30 +0100 Subject: [PATCH 195/317] Add fa_likelihood (ONGOING) --- src/hierarchies/likelihoods/CMakeLists.txt | 2 + src/hierarchies/likelihoods/fa_likelihood.cc | 57 +++++++++++++++++++ src/hierarchies/likelihoods/fa_likelihood.h | 46 +++++++++++++++ .../likelihoods/states/CMakeLists.txt | 1 + src/hierarchies/likelihoods/states/fa_state.h | 23 ++++++++ src/hierarchies/likelihoods/states/includes.h | 1 + 6 files changed, 130 insertions(+) create mode 100644 src/hierarchies/likelihoods/fa_likelihood.cc create mode 100644 src/hierarchies/likelihoods/fa_likelihood.h create mode 100644 src/hierarchies/likelihoods/states/fa_state.h diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index 42cc97da3..197d4106b 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -9,6 +9,8 @@ target_sources(bayesmix PUBLIC uni_lin_reg_likelihood.cc laplace_likelihood.h laplace_likelihood.cc + fa_likelihood.h + fa_likelihood.cc ) add_subdirectory(states) diff --git a/src/hierarchies/likelihoods/fa_likelihood.cc b/src/hierarchies/likelihoods/fa_likelihood.cc new file mode 100644 index 000000000..ae843598b --- /dev/null +++ b/src/hierarchies/likelihoods/fa_likelihood.cc @@ -0,0 +1,57 @@ +#include "fa_likelihood.h" + +#include "src/utils/distributions.h" + +void FALikelihood::set_state_from_proto( + const google::protobuf::Message& state_, bool update_card) { + auto& statecast = downcast_state(state_); + state.mu = bayesmix::to_eigen(statecast.fa_state().mu()); + state.psi = bayesmix::to_eigen(statecast.fa_state().psi()); + state.eta = bayesmix::to_eigen(statecast.fa_state().eta()); + state.lambda = bayesmix::to_eigen(statecast.fa_state().lambda()); + state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); + compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda, + state.psi_inverse); + if (update_card) set_card(statecast.cardinality()); +} + +void FALikelihood::clear_summary_statistics() { + data_sum = Eigen::VectorXd::Zero(dim); +} + +std::shared_ptr +FALikelihood::get_state_proto() const { + bayesmix::FAState state_; + bayesmix::to_proto(state.mu, state_.mutable_mu()); + bayesmix::to_proto(state.psi, state_.mutable_psi()); + bayesmix::to_proto(state.eta, state_.mutable_eta()); + bayesmix::to_proto(state.lambda, state_.mutable_lambda()); + + auto out = std::make_shared(); + out->mutable_fa_state()->CopyFrom(state_); + return out; +} + +double FALikelihood::compute_lpdf(const Eigen::RowVectorXd& datum) const { + return bayesmix::multi_normal_lpdf_woodbury_chol( + datum, state.mu, state.psi_inverse, state.cov_wood, state.cov_logdet); +} + +void FALikelihood::update_sum_stats(const Eigen::RowVectorXd& datum, + bool add) { + if (add) { + data_sum += datum; + } else { + data_sum -= datum; + } +} + +void FALikelihood::compute_wood_factors( + Eigen::MatrixXd& cov_wood, double& cov_logdet, + const Eigen::MatrixXd& lambda, + const Eigen::DiagonalMatrix& psi_inverse) { + auto [cov_wood_, cov_logdet_] = + bayesmix::compute_wood_chol_and_logdet(psi_inverse, lambda); + cov_logdet = cov_logdet_; + cov_wood = cov_wood_; +} diff --git a/src/hierarchies/likelihoods/fa_likelihood.h b/src/hierarchies/likelihoods/fa_likelihood.h new file mode 100644 index 000000000..d9cd5a77e --- /dev/null +++ b/src/hierarchies/likelihoods/fa_likelihood.h @@ -0,0 +1,46 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_FA_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_FA_LIKELIHOOD_H_ + +#include + +#include +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states/includes.h" + +class FALikelihood : public BaseLikelihood { + public: + FALikelihood() = default; + ~FALikelihood() = default; + bool is_multivariate() const override { return true; }; + bool is_dependent() const override { return false; }; + void set_state_from_proto(const google::protobuf::Message& state_, + bool update_card = true) override; + void clear_summary_statistics() override; + void set_dim(unsigned int dim_) { + dim = dim_; + clear_summary_statistics(); + }; + unsigned int get_dim() const { return dim; }; + Eigen::VectorXd get_data_sum() const { return data_sum; }; + + std::shared_ptr get_state_proto() + const override; + + protected: + double compute_lpdf(const Eigen::RowVectorXd& datum) const override; + void update_sum_stats(const Eigen::RowVectorXd& datum, bool add) override; + void compute_wood_factors( + Eigen::MatrixXd& cov_wood, double& cov_logdet, + const Eigen::MatrixXd& lambda, + const Eigen::DiagonalMatrix& psi_inverse); + + unsigned int dim; + Eigen::VectorXd data_sum; +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_FA_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/states/CMakeLists.txt b/src/hierarchies/likelihoods/states/CMakeLists.txt index 3c5a67426..89da41e53 100644 --- a/src/hierarchies/likelihoods/states/CMakeLists.txt +++ b/src/hierarchies/likelihoods/states/CMakeLists.txt @@ -2,5 +2,6 @@ target_sources(bayesmix PUBLIC uni_ls_state.h multi_ls_state.h uni_lin_reg_ls_state.h + fa_state.h includes.h ) diff --git a/src/hierarchies/likelihoods/states/fa_state.h b/src/hierarchies/likelihoods/states/fa_state.h new file mode 100644 index 000000000..a876aaf1e --- /dev/null +++ b/src/hierarchies/likelihoods/states/fa_state.h @@ -0,0 +1,23 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_FA_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_FA_STATE_H_ + +#include +#include + +#include "algorithm_state.pb.h" +#include "src/utils/eigen_utils.h" +#include "src/utils/proto_utils.h" + +namespace State { + +class FA { + public: + Eigen::VectorXd mu, psi; + Eigen::MatrixXd eta, lambda, cov_wood; + Eigen::DiagonalMatrix psi_inverse; + double cov_logdet; +}; + +} // namespace State + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_FACTOR_ANALYZERS_STATE_H_ diff --git a/src/hierarchies/likelihoods/states/includes.h b/src/hierarchies/likelihoods/states/includes.h index b1282fb6e..f4f868c52 100644 --- a/src/hierarchies/likelihoods/states/includes.h +++ b/src/hierarchies/likelihoods/states/includes.h @@ -1,6 +1,7 @@ #ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ #define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_INCLUDES_H_ +#include "fa_state.h" #include "multi_ls_state.h" #include "uni_lin_reg_ls_state.h" #include "uni_ls_state.h" From 8ff95e7f71af2333b94bfd5b5bef99d0570e957d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 19 Mar 2022 15:46:43 +0100 Subject: [PATCH 196/317] Fix set_dataset method --- src/hierarchies/abstract_hierarchy.h | 3 ++- src/hierarchies/base_hierarchy.h | 4 ++-- src/hierarchies/likelihoods/abstract_likelihood.h | 3 +++ src/hierarchies/likelihoods/base_likelihood.h | 8 ++++++++ 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 11bff7549..692b80b71 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -259,7 +259,8 @@ class AbstractHierarchy { //! Returns whether the hierarchy represents a conjugate model or not virtual bool is_conjugate() const = 0; - //! Main function that initializes members to appropriate values + //! Sets the (pointer to) the dataset in the cluster. Implemented in + //! BaseHierarchy virtual void set_dataset(const Eigen::MatrixXd *const dataset) = 0; protected: diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 31fc01508..d78db2ba6 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -376,7 +376,7 @@ class BaseHierarchy : public AbstractHierarchy { //! Sets the (pointer to the) dataset matrix void set_dataset(const Eigen::MatrixXd *const dataset) override { - dataset_ptr = dataset; + like->set_dataset(dataset); } protected: @@ -418,7 +418,7 @@ class BaseHierarchy : public AbstractHierarchy { } // TEMPORANEO! - const Eigen::MatrixXd *dataset_ptr; + // const Eigen::MatrixXd *dataset_ptr; }; // TODO: Move definitions outside the class to improve code cleaness diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 705efbb06..a1739c689 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -89,6 +89,9 @@ class AbstractLikelihood { //! BaseLikelihood virtual void write_state_to_proto(google::protobuf::Message *out) const = 0; + //! Sets the (pointer to) the dataset in the cluster + virtual void set_dataset(const Eigen::MatrixXd *const dataset) = 0; + //! Adds a datum and its index to the likelihood. Implemented in //! BaseLikelihood virtual void add_datum( diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 1e48ee64f..0ebff05fe 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -148,6 +148,11 @@ class BaseLikelihood : public AbstractLikelihood { // state.set_from_unconstrained(unconstrained_state); } + //! Sets the (pointer to) the dataset in the cluster + void set_dataset(const Eigen::MatrixXd *const dataset) { + dataset_ptr = dataset; + } + //! Adds a datum and its index to the likelihood void add_datum( const int id, const Eigen::RowVectorXd &datum, @@ -196,6 +201,9 @@ class BaseLikelihood : public AbstractLikelihood { //! Set of indexes of data points belonging to this cluster std::set cluster_data_idx; + + //! Pointer to the cluster dataset + const Eigen::MatrixXd *dataset_ptr; }; template From fa71470cf398b5c64eaa5c631a503d606688021e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 19 Mar 2022 15:47:03 +0100 Subject: [PATCH 197/317] Add TODO for later fix --- src/hierarchies/likelihoods/laplace_likelihood.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h index 37cd611ff..62d8fc14c 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.h +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -44,6 +44,7 @@ class LaplaceLikelihood double compute_lpdf(const Eigen::RowVectorXd &datum) const override; void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; + // TODO: ORA CHE HO IL DATASET QUESTO NON SERVE! //! Set of values of data points belonging to this cluster std::list cluster_data_values; //! Sum of absolute differences for current params From 5c11ad37653f9fa3589342936758f9c3d38e6500 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 19 Mar 2022 15:47:32 +0100 Subject: [PATCH 198/317] Improved code --- src/hierarchies/priors/nw_prior_model.cc | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/hierarchies/priors/nw_prior_model.cc b/src/hierarchies/priors/nw_prior_model.cc index ab7090dc8..f0bb2dfe7 100644 --- a/src/hierarchies/priors/nw_prior_model.cc +++ b/src/hierarchies/priors/nw_prior_model.cc @@ -136,19 +136,12 @@ std::shared_ptr NWPriorModel::sample( params.mean, tau_new * params.var_scaling, rng); write_prec_to_state(tau_new, &out); - // Translate to proto - bayesmix::Vector mean_proto; - bayesmix::Matrix prec_proto, prec_chol_proto; - bayesmix::to_proto(out.mean, &mean_proto); - bayesmix::to_proto(out.prec, &prec_proto); - bayesmix::to_proto(out.prec_chol, &prec_chol_proto); - // Make output state bayesmix::AlgorithmState::ClusterState state; - state.mutable_multi_ls_state()->mutable_mean()->CopyFrom(mean_proto); - state.mutable_multi_ls_state()->mutable_prec()->CopyFrom(prec_proto); - state.mutable_multi_ls_state()->mutable_prec_chol()->CopyFrom( - prec_chol_proto); + bayesmix::to_proto(out.mean, state.mutable_multi_ls_state()->mutable_mean()); + bayesmix::to_proto(out.prec, state.mutable_multi_ls_state()->mutable_prec()); + bayesmix::to_proto(out.prec_chol, + state.mutable_multi_ls_state()->mutable_prec_chol()); return std::make_shared(state); }; From 999ae01a131e56322c902a7802a2736461d42809 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Sat, 19 Mar 2022 15:48:31 +0100 Subject: [PATCH 199/317] Add FAPriorModel (ONGOING) --- src/hierarchies/priors/CMakeLists.txt | 2 + src/hierarchies/priors/fa_prior_model.cc | 178 +++++++++++++++++++++++ src/hierarchies/priors/fa_prior_model.h | 44 ++++++ src/hierarchies/priors/hyperparams.h | 6 + src/proto/hierarchy_prior.proto | 1 + 5 files changed, 231 insertions(+) create mode 100644 src/hierarchies/priors/fa_prior_model.cc create mode 100644 src/hierarchies/priors/fa_prior_model.h diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt index 5b218e6c0..ddbf6e4da 100644 --- a/src/hierarchies/priors/CMakeLists.txt +++ b/src/hierarchies/priors/CMakeLists.txt @@ -10,4 +10,6 @@ target_sources(bayesmix PUBLIC nw_prior_model.cc mnig_prior_model.h mnig_prior_model.cc + fa_prior_model.h + fa_prior_model.cc ) diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc new file mode 100644 index 000000000..ccc541f7a --- /dev/null +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -0,0 +1,178 @@ +#include "fa_prior_model.h" + +double FAPriorModel::lpdf(const google::protobuf::Message &state_) { + // Downcast state + auto &state = downcast_state(state_).fa_state(); + // Proto to Eigen conversion + Eigen::VectorXd mu = bayesmix::to_eigen(state.mu()); + Eigen::VectorXd psi = bayesmix::to_eigen(state.psi()); + Eigen::MatrixXd eta = bayesmix::to_eigen(state.eta()); + Eigen::MatrixXd lambda = bayesmix::to_eigen(state.lambda()); + // Initialize lpdf value + double target = 0.; + // Compute lpdf + for (size_t j = 0; j < dim; j++) { + target += + stan::math::normal_lpdf(mu(j), hypers->mutilde(j), sqrt(hypers->phi)); + target += + stan::math::inv_gamma_lpdf(psi(j), hypers->alpha0, hypers->beta(j)); + for (size_t i = 0; i < hypers->q; i++) { + target += stan::math::normal_lpdf(lambda(j, i), 0, 1); + } + } + for (size_t i = 0; i < eta.rows(); i++) { + for (size_t j = 0; j < hypers->q; j++) { + target += stan::math::normal_lpdf(eta(i, j), 0, 1); + } + } + // Return lpdf contribution + return target; +} + +std::shared_ptr FAPriorModel::sample( + bool use_post_hypers) { + // Random seed + auto &rng = bayesmix::Rng::Instance().get(); + + // Select params to use + Hyperparams::FA params = use_post_hypers ? post_hypers : *hypers; + + // HO AGGIUNTO PARAMS.CARD MA NON SO SE SIA LA SCELTA MIGLIORE!!! + // Compute output state + State::FA out; + out.mu = params.mutilde; + out.psi = params.beta / (params.alpha0 + 1.); + out.eta = Eigen::MatrixXd::Zero(params.card, params.q); + out.lambda = Eigen::MatrixXd::Zero(dim, params.q); + for (size_t j = 0; j < dim; j++) { + out.mu[j] = + stan::math::normal_rng(params.mutilde[j], sqrt(params.phi), rng); + + out.psi[j] = stan::math::inv_gamma_rng(params.alpha0, params.beta[j], rng); + + for (size_t i = 0; i < params.q; i++) { + out.lambda(j, i) = stan::math::normal_rng(0, 1, rng); + } + } + for (size_t i = 0; i < params.card; i++) { + for (size_t j = 0; j < params.q; j++) { + out.eta(i, j) = stan::math::normal_rng(0, 1, rng); + } + } + + // Questi conti non li passo al proto, attenzione !!! + // out.psi_inverse = out.psi.cwiseInverse().asDiagonal(); + // compute_wood_factors(out.cov_wood, out.cov_logdet, out.lambda, + // out.psi_inverse); + + // Convert to proto + bayesmix::AlgorithmState::ClusterState state; + bayesmix::to_proto(out.mu, state.mutable_fa_state()->mutable_mu()); + bayesmix::to_proto(out.psi, state.mutable_fa_state()->mutable_psi()); + bayesmix::to_proto(out.eta, state.mutable_fa_state()->mutable_eta()); + bayesmix::to_proto(out.lambda, state.mutable_fa_state()->mutable_lambda()); + return std::make_shared(state); + + // MANCA PSI_INVERSE E GLI OUTPUT DA COMPUTE_WOOD_FACTORS !!! BISOGNA + // CAMBIARE IL PROTO +} + +void FAPriorModel::update_hypers( + const std::vector &states) { + auto &rng = bayesmix::Rng::Instance().get(); + if (prior->has_fixed_values()) { + return; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void FAPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).fa_state(); + hypers->mutilde = bayesmix::to_eigen(hyperscast.mutilde()); + hypers->alpha0 = hyperscast.alpha0(); + hypers->beta = bayesmix::to_eigen(hyperscast.beta()); + hypers->phi = hyperscast.phi(); + hypers->q = hyperscast.q(); + hypers->card = hyperscast.card(); +} + +std::shared_ptr +FAPriorModel::get_hypers_proto() const { + bayesmix::FAPriorDistribution hypers_; + bayesmix::to_proto(hypers->mutilde, hypers_.mutable_mutilde()); + bayesmix::to_proto(hypers->beta, hypers_.mutable_beta()); + hypers_.set_alpha0(hypers->alpha0); + hypers_.set_phi(hypers->phi); + hypers_.set_q(hypers->q); + hypers_.set_card(hypers->card); + + auto out = std::make_shared(); + out->mutable_fa_state()->CopyFrom(hypers_); + return out; +} + +void FAPriorModel::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->mutilde = bayesmix::to_eigen(prior->fixed_values().mutilde()); + dim = hypers->mutilde.size(); + hypers->beta = bayesmix::to_eigen(prior->fixed_values().beta()); + hypers->phi = prior->fixed_values().phi(); + hypers->alpha0 = prior->fixed_values().alpha0(); + hypers->q = prior->fixed_values().q(); + hypers->card = prior->fixed_values().card(); + + // Check validity + if (dim != hypers->beta.rows()) { + throw std::invalid_argument( + "Hyperparameters dimensions are not consistent"); + } + for (size_t j = 0; j < dim; j++) { + if (hypers->beta[j] <= 0) { + throw std::invalid_argument("Shape parameter must be > 0"); + } + } + if (hypers->alpha0 <= 0) { + throw std::invalid_argument("Scale parameter must be > 0"); + } + if (hypers->phi <= 0) { + throw std::invalid_argument("Diffusion parameter must be > 0"); + } + if (hypers->q <= 0) { + throw std::invalid_argument("Number of factors must be > 0"); + } + if (hypers->card <= 0) { + throw std::invalid_argument("Number of data must be > 0"); + } + } + + else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +/* +// Automatic initialization +if (dim == 0) { + hypers->mutilde = dataset_ptr->colwise().mean(); + dim = hypers->mutilde.size(); +} +if (hypers->beta.size() == 0) { + Eigen::MatrixXd centered = + dataset_ptr->rowwise() - dataset_ptr->colwise().mean(); + auto cov_llt = ((centered.transpose() * centered) / + double(dataset_ptr->rows() - 1.)) + .llt(); + Eigen::MatrixXd precision_matrix( + cov_llt.solve(Eigen::MatrixXd::Identity(dim, dim))); + hypers->beta = + (hypers->alpha0 - 1) * precision_matrix.diagonal().cwiseInverse(); + if (hypers->alpha0 == 1) { + throw std::invalid_argument( + "Scale parameter must be different than 1 when automatic " + "initialization is used"); + } +} +*/ diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h new file mode 100644 index 000000000..cd4d6eea4 --- /dev/null +++ b/src/hierarchies/priors/fa_prior_model.h @@ -0,0 +1,44 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_FA_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_FA_PRIOR_MODEL_H_ + +// #include + +#include +#include +#include + +// #include "algorithm_state.pb.h" +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +class FAPriorModel + : public BasePriorModel { + public: + FAPriorModel() = default; + ~FAPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + std::shared_ptr sample( + bool use_post_hypers) override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + unsigned int get_dim() const { return dim; }; + + protected: + std::shared_ptr get_hypers_proto() + const override; + + void initialize_hypers() override; + + unsigned int dim; +}; + +#endif // BAYESMIX_HIERARCHIES_PRIORS_FA_PRIOR_MODEL_H_ diff --git a/src/hierarchies/priors/hyperparams.h b/src/hierarchies/priors/hyperparams.h index 9043a6435..04dcbd84b 100644 --- a/src/hierarchies/priors/hyperparams.h +++ b/src/hierarchies/priors/hyperparams.h @@ -25,6 +25,12 @@ struct MNIG { double shape, scale; }; +struct FA { + Eigen::VectorXd mutilde, beta; + double phi, alpha0; + unsigned int card, q; +}; + } // namespace Hyperparams #endif // BAYESMIX_HIERARCHIES_PRIORS_HYPERPARAMS_H_ diff --git a/src/proto/hierarchy_prior.proto b/src/proto/hierarchy_prior.proto index 866189f6a..a34ced9db 100644 --- a/src/proto/hierarchy_prior.proto +++ b/src/proto/hierarchy_prior.proto @@ -102,6 +102,7 @@ message LinRegUniPrior { double phi = 3; double alpha0 = 4; uint32 q = 5; + uint32 card = 6; } From b4359195ef008b5aa2e02ff569aa9bd11a05ea32 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 14:20:41 +0100 Subject: [PATCH 200/317] Ignore old hierarchies (TMP) --- src/.gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 src/.gitignore diff --git a/src/.gitignore b/src/.gitignore new file mode 100644 index 000000000..1b5d07b0b --- /dev/null +++ b/src/.gitignore @@ -0,0 +1 @@ +hierarchies/.old/ From cd5ae366d179d2e660650b36eca2bb0d0b104100 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 14:21:11 +0100 Subject: [PATCH 201/317] Remove old fa_hierarchy --- src/hierarchies/OLD_fa_hierarchy.cc | 299 ---------------------------- src/hierarchies/OLD_fa_hierarchy.h | 134 ------------- 2 files changed, 433 deletions(-) delete mode 100644 src/hierarchies/OLD_fa_hierarchy.cc delete mode 100644 src/hierarchies/OLD_fa_hierarchy.h diff --git a/src/hierarchies/OLD_fa_hierarchy.cc b/src/hierarchies/OLD_fa_hierarchy.cc deleted file mode 100644 index 65e8c716e..000000000 --- a/src/hierarchies/OLD_fa_hierarchy.cc +++ /dev/null @@ -1,299 +0,0 @@ -#include - -#include -#include -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "fa_hierarchy.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "src/utils/proto_utils.h" -#include "src/utils/rng.h" - -double FAHierarchy::like_lpdf(const Eigen::RowVectorXd& datum) const { - return bayesmix::multi_normal_lpdf_woodbury_chol( - datum, state.mu, state.psi_inverse, state.cov_wood, state.cov_logdet); -} - -FA::State FAHierarchy::draw(const FA::Hyperparams& params) { - auto& rng = bayesmix::Rng::Instance().get(); - FA::State out; - out.mu = params.mutilde; - out.psi = params.beta / (params.alpha0 + 1.); - out.eta = Eigen::MatrixXd::Zero(card, params.q); - out.lambda = Eigen::MatrixXd::Zero(dim, params.q); - - for (size_t j = 0; j < dim; j++) { - out.mu[j] = - stan::math::normal_rng(params.mutilde[j], sqrt(params.phi), rng); - - out.psi[j] = stan::math::inv_gamma_rng(params.alpha0, params.beta[j], rng); - - for (size_t i = 0; i < params.q; i++) { - out.lambda(j, i) = stan::math::normal_rng(0, 1, rng); - } - } - - for (size_t i = 0; i < card; i++) { - for (size_t j = 0; j < params.q; j++) { - out.eta(i, j) = stan::math::normal_rng(0, 1, rng); - } - } - - out.psi_inverse = out.psi.cwiseInverse().asDiagonal(); - compute_wood_factors(out.cov_wood, out.cov_logdet, out.lambda, - out.psi_inverse); - - return out; -} - -void FAHierarchy::initialize_state() { - state.mu = hypers->mutilde; - state.psi = hypers->beta / (hypers->alpha0 + 1.); - state.eta = Eigen::MatrixXd::Zero(card, hypers->q); - state.lambda = Eigen::MatrixXd::Zero(dim, hypers->q); - state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); - compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda, - state.psi_inverse); -} - -void FAHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mutilde = bayesmix::to_eigen(prior->fixed_values().mutilde()); - dim = hypers->mutilde.size(); - hypers->beta = bayesmix::to_eigen(prior->fixed_values().beta()); - hypers->phi = prior->fixed_values().phi(); - hypers->alpha0 = prior->fixed_values().alpha0(); - hypers->q = prior->fixed_values().q(); - - // Automatic initialization - if (dim == 0) { - hypers->mutilde = dataset_ptr->colwise().mean(); - dim = hypers->mutilde.size(); - } - if (hypers->beta.size() == 0) { - Eigen::MatrixXd centered = - dataset_ptr->rowwise() - dataset_ptr->colwise().mean(); - auto cov_llt = ((centered.transpose() * centered) / - double(dataset_ptr->rows() - 1.)) - .llt(); - Eigen::MatrixXd precision_matrix( - cov_llt.solve(Eigen::MatrixXd::Identity(dim, dim))); - hypers->beta = - (hypers->alpha0 - 1) * precision_matrix.diagonal().cwiseInverse(); - if (hypers->alpha0 == 1) { - throw std::invalid_argument( - "Scale parameter must be different than 1 when automatic " - "initialization is used"); - } - } - // Check validity - if (dim != hypers->beta.rows()) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - for (size_t j = 0; j < dim; j++) { - if (hypers->beta[j] <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - } - if (hypers->alpha0 <= 0) { - throw std::invalid_argument("Scale parameter must be > 0"); - } - if (hypers->phi <= 0) { - throw std::invalid_argument("Diffusion parameter must be > 0"); - } - if (hypers->q <= 0) { - throw std::invalid_argument("Number of factors must be > 0"); - } - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void FAHierarchy::update_hypers( - const std::vector& states) { - auto& rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void FAHierarchy::update_summary_statistics(const Eigen::RowVectorXd& datum, - const bool add) { - if (add) { - data_sum += datum; - } else { - data_sum -= datum; - } -} - -void FAHierarchy::clear_summary_statistics() { - data_sum = Eigen::VectorXd::Zero(dim); -} - -void FAHierarchy::set_state_from_proto( - const google::protobuf::Message& state_) { - auto& statecast = downcast_state(state_); - state.mu = bayesmix::to_eigen(statecast.fa_state().mu()); - state.psi = bayesmix::to_eigen(statecast.fa_state().psi()); - state.eta = bayesmix::to_eigen(statecast.fa_state().eta()); - state.lambda = bayesmix::to_eigen(statecast.fa_state().lambda()); - state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); - compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda, - state.psi_inverse); - set_card(statecast.cardinality()); -} - -std::shared_ptr -FAHierarchy::get_state_proto() const { - bayesmix::FAState state_; - bayesmix::to_proto(state.mu, state_.mutable_mu()); - bayesmix::to_proto(state.psi, state_.mutable_psi()); - bayesmix::to_proto(state.eta, state_.mutable_eta()); - bayesmix::to_proto(state.lambda, state_.mutable_lambda()); - - auto out = std::make_shared(); - out->mutable_fa_state()->CopyFrom(state_); - return out; -} - -void FAHierarchy::set_hypers_from_proto( - const google::protobuf::Message& hypers_) { - auto& hyperscast = downcast_hypers(hypers_).fa_state(); - hypers->mutilde = bayesmix::to_eigen(hyperscast.mutilde()); - hypers->alpha0 = hyperscast.alpha0(); - hypers->beta = bayesmix::to_eigen(hyperscast.beta()); - hypers->phi = hyperscast.phi(); - hypers->q = hyperscast.q(); -} - -std::shared_ptr -FAHierarchy::get_hypers_proto() const { - bayesmix::FAPriorDistribution hypers_; - bayesmix::to_proto(hypers->mutilde, hypers_.mutable_mutilde()); - bayesmix::to_proto(hypers->beta, hypers_.mutable_beta()); - hypers_.set_alpha0(hypers->alpha0); - hypers_.set_phi(hypers->phi); - hypers_.set_q(hypers->q); - - auto out = std::make_shared(); - out->mutable_fa_state()->CopyFrom(hypers_); - return out; -} - -void FAHierarchy::sample_full_cond(const bool update_params /*= false*/) { - if (this->card == 0) { - // No posterior update possible - sample_prior(); - } else { - sample_eta(); - sample_mu(); - sample_psi(); - sample_lambda(); - } -} - -void FAHierarchy::sample_eta() { - auto& rng = bayesmix::Rng::Instance().get(); - auto sigma_eta_inv_llt = - (Eigen::MatrixXd::Identity(hypers->q, hypers->q) + - state.lambda.transpose() * state.psi_inverse * state.lambda) - .llt(); - if (state.eta.rows() != card) { - state.eta = Eigen::MatrixXd::Zero(card, state.eta.cols()); - } - Eigen::MatrixXd temp_product( - sigma_eta_inv_llt.solve(state.lambda.transpose() * state.psi_inverse)); - auto iterator = cluster_data_idx.begin(); - for (size_t i = 0; i < card; i++, iterator++) { - Eigen::VectorXd tempvector(dataset_ptr->row( - *iterator)); // TODO use slicing when Eigen is updated to v3.4 - state.eta.row(i) = (bayesmix::multi_normal_prec_chol_rng( - temp_product * (tempvector - state.mu), sigma_eta_inv_llt, rng)); - } -} - -void FAHierarchy::sample_mu() { - auto& rng = bayesmix::Rng::Instance().get(); - Eigen::DiagonalMatrix sigma_mu; - - sigma_mu.diagonal() = - (card * state.psi_inverse.diagonal().array() + hypers->phi) - .cwiseInverse(); - - Eigen::VectorXd sum = (state.eta.colwise().sum()); - - Eigen::VectorXd mumean = - sigma_mu * (hypers->phi * hypers->mutilde + - state.psi_inverse * (data_sum - state.lambda * sum)); - - state.mu = bayesmix::multi_normal_diag_rng(mumean, sigma_mu, rng); -} - -void FAHierarchy::sample_lambda() { - auto& rng = bayesmix::Rng::Instance().get(); - - Eigen::MatrixXd temp_etateta(state.eta.transpose() * state.eta); - - for (size_t j = 0; j < dim; j++) { - auto sigma_lambda_inv_llt = - (Eigen::MatrixXd::Identity(hypers->q, hypers->q) + - temp_etateta / state.psi[j]) - .llt(); - Eigen::VectorXd tempsum(card); - const Eigen::VectorXd& data_col = dataset_ptr->col(j); - auto iterator = cluster_data_idx.begin(); - for (size_t i = 0; i < card; i++, iterator++) { - tempsum[i] = data_col( - *iterator); // TODO use slicing when Eigen is updated to v3.4 - } - tempsum = tempsum.array() - state.mu[j]; - tempsum = tempsum.array() / state.psi[j]; - - state.lambda.row(j) = bayesmix::multi_normal_prec_chol_rng( - sigma_lambda_inv_llt.solve(state.eta.transpose() * tempsum), - sigma_lambda_inv_llt, rng); - } -} - -void FAHierarchy::sample_psi() { - auto& rng = bayesmix::Rng::Instance().get(); - - for (size_t j = 0; j < dim; j++) { - double sum = 0; - auto iterator = cluster_data_idx.begin(); - for (size_t i = 0; i < card; i++, iterator++) { - sum += std::pow( - ((*dataset_ptr)(*iterator, j) - - state.mu[j] - // TODO use slicing when Eigen is updated to v3.4 - state.lambda.row(j).dot(state.eta.row(i))), - 2); - } - state.psi[j] = stan::math::inv_gamma_rng(hypers->alpha0 + card / 2, - hypers->beta[j] + sum / 2, rng); - } - state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); - compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda, - state.psi_inverse); -} - -void FAHierarchy::compute_wood_factors( - Eigen::MatrixXd& cov_wood, double& cov_logdet, - const Eigen::MatrixXd& lambda, - const Eigen::DiagonalMatrix& psi_inverse) { - auto [cov_wood_, cov_logdet_] = - bayesmix::compute_wood_chol_and_logdet(psi_inverse, lambda); - cov_logdet = cov_logdet_; - cov_wood = cov_wood_; -} diff --git a/src/hierarchies/OLD_fa_hierarchy.h b/src/hierarchies/OLD_fa_hierarchy.h deleted file mode 100644 index 8b6da31d4..000000000 --- a/src/hierarchies/OLD_fa_hierarchy.h +++ /dev/null @@ -1,134 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "base_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" -#include "src/utils/distributions.h" - -//! Mixture of Factor Analysers hierarchy for multivariate data. - -//! This class represents a hierarchical model where data are distributed - -namespace FA { -//! Custom container for State values -struct State { - Eigen::VectorXd mu, psi; - Eigen::MatrixXd eta, lambda, cov_wood; - Eigen::DiagonalMatrix psi_inverse; - double cov_logdet; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - Eigen::VectorXd mutilde, beta; - double phi, alpha0; - unsigned int q; -}; - -}; // namespace FA - -class FAHierarchy : public BaseHierarchy { - public: - FAHierarchy() = default; - ~FAHierarchy() = default; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector& - states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - FA::State draw(const FA::Hyperparams& params); - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Returns the Protobuf ID associated to this class - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::FA; - } - - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message& state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message& hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return true; } - - //! Saves posterior hyperparameters to the corresponding class member - void save_posterior_hypers() { - // No Hyperprior present in the hierarchy - } - - //! Generates new state values from the centering posterior distribution - //! @param update_params Save posterior hypers after the computation? - void sample_full_cond(const bool update_params = false) override; - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd& datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd& datum, - const bool add) override; - - //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Gibbs sampling step for state variable eta - void sample_eta(); - - //! Gibbs sampling step for state variable mu - void sample_mu(); - - //! Gibbs sampling step for state variable psi - void sample_psi(); - - //! Gibbs sampling step for state variable lambda - void sample_lambda(); - - //! Helper function to compute factors needed for likelihood evaluation - void compute_wood_factors( - Eigen::MatrixXd& cov_wood, double& cov_logdet, - const Eigen::MatrixXd& lambda, - const Eigen::DiagonalMatrix& psi_inverse); - - //! Sum of data points currently belonging to the cluster - Eigen::VectorXd data_sum; - - //! Number of variables for each datum - size_t dim; -}; - -#endif // BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ From fb7de86c2bd7abbc55130f0874f9fb3644e0e555 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:42:44 +0100 Subject: [PATCH 202/317] Revert changes --- src/proto/hierarchy_prior.proto | 1 - 1 file changed, 1 deletion(-) diff --git a/src/proto/hierarchy_prior.proto b/src/proto/hierarchy_prior.proto index a34ced9db..866189f6a 100644 --- a/src/proto/hierarchy_prior.proto +++ b/src/proto/hierarchy_prior.proto @@ -102,7 +102,6 @@ message LinRegUniPrior { double phi = 3; double alpha0 = 4; uint32 q = 5; - uint32 card = 6; } From 8d023ae05586cad95e25d67b9aa8eb8d84309eb3 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:43:10 +0100 Subject: [PATCH 203/317] Add HierarchyId::FA --- src/proto/hierarchy_id.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/proto/hierarchy_id.proto b/src/proto/hierarchy_id.proto index 6ab62ab93..dcd870592 100644 --- a/src/proto/hierarchy_id.proto +++ b/src/proto/hierarchy_id.proto @@ -12,5 +12,5 @@ enum HierarchyId { LinRegUni = 3; // Linear Regression (univariate response) NNxIG = 4; // Normal - Normal x Inverse Gamma LapNIG = 5; // Laplace - Normal Inverse Gamma - // FA = 6; // Factor Analysers + FA = 6; // Factor Analysers } From f4b8070bc6988023225dcd3f17ae22fa8fdc6ddb Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:44:04 +0100 Subject: [PATCH 204/317] Add fa_hierarchy.h target source --- src/hierarchies/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index e8f5e427e..a66024151 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -9,7 +9,7 @@ target_sources(bayesmix # conjugate_hierarchy.h lin_reg_uni_hierarchy.h # lin_reg_uni_hierarchy.cc - # fa_hierarchy.h + fa_hierarchy.h # fa_hierarchy.cc lapnig_hierarchy.h # lapnig_hierarchy.cc From 2c44c438eb2e846633559823bd0d4f4c36744e03 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:44:35 +0100 Subject: [PATCH 205/317] Add FAUpdater --- src/hierarchies/updaters/CMakeLists.txt | 2 + src/hierarchies/updaters/fa_updater.cc | 139 ++++++++++++++++++++++++ src/hierarchies/updaters/fa_updater.h | 37 +++++++ 3 files changed, 178 insertions(+) create mode 100644 src/hierarchies/updaters/fa_updater.cc create mode 100644 src/hierarchies/updaters/fa_updater.h diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index 1a3dd4d82..3a35d9f3e 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -14,4 +14,6 @@ target_sources(bayesmix PUBLIC nnw_updater.cc mnig_updater.h mnig_updater.cc + fa_updater.h + fa_updater.cc ) diff --git a/src/hierarchies/updaters/fa_updater.cc b/src/hierarchies/updaters/fa_updater.cc new file mode 100644 index 000000000..3f8f81694 --- /dev/null +++ b/src/hierarchies/updaters/fa_updater.cc @@ -0,0 +1,139 @@ +#include "fa_updater.h" + +#include "src/utils/distributions.h" + +void FAUpdater::draw(AbstractLikelihood& like, AbstractPriorModel& prior, + bool update_params) { + // Likelihood and PriorModel downcast + auto& likecast = static_cast(like); + auto& priorcast = static_cast(prior); + // Sample from the full conditional of the fa hierarchy + bool set_card = true, use_post_hypers = true; + if (likecast.get_card() == 0) { + likecast.set_state_from_proto(*priorcast.sample(!use_post_hypers), + !set_card); + } else { + // Get state and hypers + State::FA new_state = likecast.get_state(); + Hyperparams::FA hypers = priorcast.get_hypers(); + // Gibbs update + sample_eta(new_state, hypers, likecast); + sample_mu(new_state, hypers, likecast); + // sample_psi(new_state, hypers, likecast.get_dataset(), + // likecast.get_data_idx(), priorcast.get_dim()); sample_lambda(new_state, + // hypers, likecast.get_dataset(), likecast.get_data_idx(), + // priorcast.get_dim()); Eigen2Proto conversion + bayesmix::AlgorithmState::ClusterState new_state_proto; + bayesmix::to_proto(new_state.eta, + new_state_proto.mutable_fa_state()->mutable_eta()); + bayesmix::to_proto(new_state.mu, + new_state_proto.mutable_fa_state()->mutable_mu()); + bayesmix::to_proto(new_state.psi, + new_state_proto.mutable_fa_state()->mutable_psi()); + bayesmix::to_proto(new_state.lambda, + new_state_proto.mutable_fa_state()->mutable_lambda()); + likecast.set_state_from_proto(new_state_proto, !set_card); + } +} + +void FAUpdater::sample_eta(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like) { + // Random Seed + auto& rng = bayesmix::Rng::Instance().get(); + // Get required information + auto dataset_ptr = like.get_dataset(); + auto cluster_data_idx = like.get_data_idx(); + unsigned int card = like.get_card(); + // eta update + auto sigma_eta_inv_llt = + (Eigen::MatrixXd::Identity(hypers.q, hypers.q) + + state.lambda.transpose() * state.psi_inverse * state.lambda) + .llt(); + if (state.eta.rows() != card) { + state.eta = Eigen::MatrixXd::Zero(card, state.eta.cols()); + } + Eigen::MatrixXd temp_product( + sigma_eta_inv_llt.solve(state.lambda.transpose() * state.psi_inverse)); + auto iterator = cluster_data_idx.begin(); + for (size_t i = 0; i < card; i++, iterator++) { + Eigen::VectorXd tempvector(dataset_ptr->row( + *iterator)); // TODO use slicing when Eigen is updated to v3.4 + state.eta.row(i) = (bayesmix::multi_normal_prec_chol_rng( + temp_product * (tempvector - state.mu), sigma_eta_inv_llt, rng)); + } +} + +void FAUpdater::sample_mu(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like) { + // Random seed + auto& rng = bayesmix::Rng::Instance().get(); + // Get required information + Eigen::VectorXd data_sum = like.get_data_sum(); + unsigned int card = like.get_card(); + // mu update + Eigen::DiagonalMatrix sigma_mu; + sigma_mu.diagonal() = + (card * state.psi_inverse.diagonal().array() + hypers.phi) + .cwiseInverse(); + Eigen::VectorXd sum = (state.eta.colwise().sum()); + Eigen::VectorXd mumean = + sigma_mu * (hypers.phi * hypers.mutilde + + state.psi_inverse * (data_sum - state.lambda * sum)); + state.mu = bayesmix::multi_normal_diag_rng(mumean, sigma_mu, rng); +} + +void FAUpdater::sample_lambda(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like) { + // Random seed + auto& rng = bayesmix::Rng::Instance().get(); + // Getting required information + unsigned int dim = like.get_dim(); + unsigned int card = like.get_card(); + auto dataset_ptr = like.get_dataset(); + auto cluster_data_idx = like.get_data_idx(); + // lambda update + Eigen::MatrixXd temp_etateta(state.eta.transpose() * state.eta); + for (size_t j = 0; j < dim; j++) { + auto sigma_lambda_inv_llt = + (Eigen::MatrixXd::Identity(hypers.q, hypers.q) + + temp_etateta / state.psi[j]) + .llt(); + Eigen::VectorXd tempsum(card); + const Eigen::VectorXd& data_col = dataset_ptr->col(j); + auto iterator = cluster_data_idx.begin(); + for (size_t i = 0; i < card; i++, iterator++) { + tempsum[i] = data_col( + *iterator); // TODO use slicing when Eigen is updated to v3.4 + } + tempsum = tempsum.array() - state.mu[j]; + tempsum = tempsum.array() / state.psi[j]; + state.lambda.row(j) = bayesmix::multi_normal_prec_chol_rng( + sigma_lambda_inv_llt.solve(state.eta.transpose() * tempsum), + sigma_lambda_inv_llt, rng); + } +} + +void FAUpdater::sample_psi(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like) { + // Random seed + auto& rng = bayesmix::Rng::Instance().get(); + // Getting required information + auto dataset_ptr = like.get_dataset(); + auto cluster_data_idx = like.get_data_idx(); + unsigned int dim = like.get_dim(); + unsigned int card = like.get_card(); + // psi update + for (size_t j = 0; j < dim; j++) { + double sum = 0; + auto iterator = cluster_data_idx.begin(); + for (size_t i = 0; i < card; i++, iterator++) { + sum += std::pow( + ((*dataset_ptr)(*iterator, j) - + state.mu[j] - // TODO use slicing when Eigen is updated to v3.4 + state.lambda.row(j).dot(state.eta.row(i))), + 2); + } + state.psi[j] = stan::math::inv_gamma_rng(hypers.alpha0 + card / 2, + hypers.beta[j] + sum / 2, rng); + } +} diff --git a/src/hierarchies/updaters/fa_updater.h b/src/hierarchies/updaters/fa_updater.h new file mode 100644 index 000000000..b1ffcb3ae --- /dev/null +++ b/src/hierarchies/updaters/fa_updater.h @@ -0,0 +1,37 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_FA_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_FA_UPDATER_H_ + +#include "abstract_updater.h" +#include "src/hierarchies/likelihoods/fa_likelihood.h" +#include "src/hierarchies/likelihoods/states/includes.h" +#include "src/hierarchies/priors/fa_prior_model.h" +#include "src/hierarchies/priors/hyperparams.h" +#include "src/utils/proto_utils.h" + +class FAUpdater : public AbstractUpdater { + public: + FAUpdater() = default; + ~FAUpdater() = default; + void draw(AbstractLikelihood& like, AbstractPriorModel& prior, + bool update_params) override; + + protected: + void sample_eta(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like); + void sample_mu(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like); + void sample_lambda(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like); + void sample_psi(State::FA& state, const Hyperparams::FA& hypers, + const FALikelihood& like); + // void sample_eta(State::FA & state, const Hyperparams::FA & hypers, const + // Eigen::MatrixXd * dataset_ptr, const std::set & cluster_data_idx); + // void sample_mu(State::FA & state, const Hyperparams::FA & hypers, const + // Eigen::VectorXd & data_sum); void sample_lambda(State::FA & state, const + // Hyperparams::FA & hypers, const Eigen::MatrixXd * dataset_ptr, const + // std::set & cluster_data_idx, size_t dim); void sample_psi(State::FA & + // state, const Hyperparams::FA & hypers, const Eigen::MatrixXd * + // dataset_ptr, const std::set & cluster_data_idx, size_t dim); +}; + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_FA_UPDATER_H_ From 8056fdef4b7637c795465674a03d4d2d8648b47f Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:45:21 +0100 Subject: [PATCH 206/317] fa_hierarchy tests are successful --- test/hierarchies.cc | 135 ++++++++++++++++++++++---------------------- 1 file changed, 67 insertions(+), 68 deletions(-) diff --git a/test/hierarchies.cc b/test/hierarchies.cc index 751add880..ed9b09e63 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -5,8 +5,8 @@ #include "algorithm_state.pb.h" #include "ls_state.pb.h" +#include "src/hierarchies/fa_hierarchy.h" #include "src/hierarchies/lin_reg_uni_hierarchy.h" -// #include "src/hierarchies/fa_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" #include "src/hierarchies/nnw_hierarchy.h" #include "src/includes.h" @@ -300,39 +300,39 @@ TEST(nnxig_hierarchy, sample_given_data) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -// TEST(fahierarchy, draw) { -// auto hier = std::make_shared(); -// bayesmix::FAPrior prior; -// Eigen::VectorXd mutilde(4); -// mutilde << 3.0, 3.0, 4.0, 1.0; -// bayesmix::Vector mutilde_proto; -// bayesmix::to_proto(mutilde, &mutilde_proto); -// int q = 2; -// double phi = 1.0; -// double alpha0 = 5.0; -// Eigen::VectorXd beta(4); -// beta << 3.0, 3.0, 2.0, 2.1; -// bayesmix::Vector beta_proto; -// bayesmix::to_proto(beta, &beta_proto); -// *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; -// prior.mutable_fixed_values()->set_phi(phi); -// prior.mutable_fixed_values()->set_alpha0(alpha0); -// prior.mutable_fixed_values()->set_q(q); -// *prior.mutable_fixed_values()->mutable_beta() = beta_proto; -// hier->get_mutable_prior()->CopyFrom(prior); -// hier->initialize(); +TEST(fa_hierarchy, draw) { + auto hier = std::make_shared(); + bayesmix::FAPrior prior; + Eigen::VectorXd mutilde(4); + mutilde << 3.0, 3.0, 4.0, 1.0; + bayesmix::Vector mutilde_proto; + bayesmix::to_proto(mutilde, &mutilde_proto); + int q = 2; + double phi = 1.0; + double alpha0 = 5.0; + Eigen::VectorXd beta(4); + beta << 3.0, 3.0, 2.0, 2.1; + bayesmix::Vector beta_proto; + bayesmix::to_proto(beta, &beta_proto); + *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; + prior.mutable_fixed_values()->set_phi(phi); + prior.mutable_fixed_values()->set_alpha0(alpha0); + prior.mutable_fixed_values()->set_q(q); + *prior.mutable_fixed_values()->mutable_beta() = beta_proto; + hier->get_mutable_prior()->CopyFrom(prior); + hier->initialize(); -// auto hier2 = hier->clone(); -// hier2->sample_prior(); + auto hier2 = hier->clone(); + hier2->sample_prior(); -// bayesmix::AlgorithmState out; -// bayesmix::AlgorithmState::ClusterState* clusval = -// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 -// = out.add_cluster_states(); hier->write_state_to_proto(clusval); -// hier2->write_state_to_proto(clusval2); + bayesmix::AlgorithmState out; + bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); + bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); + hier->write_state_to_proto(clusval); + hier2->write_state_to_proto(clusval2); -// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -// } + ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +} // TEST(fahierarchy, draw_auto) { // auto hier = std::make_shared(); @@ -372,42 +372,41 @@ TEST(nnxig_hierarchy, sample_given_data) { // << clusval->DebugString() << clusval2->DebugString(); // } -// TEST(fahierarchy, sample_given_data) { -// auto hier = std::make_shared(); -// bayesmix::FAPrior prior; -// Eigen::VectorXd mutilde(4); -// mutilde << 3.0, 3.0, 4.0, 1.0; -// bayesmix::Vector mutilde_proto; -// bayesmix::to_proto(mutilde, &mutilde_proto); -// int q = 2; -// double phi = 1.0; -// double alpha0 = 5.0; -// Eigen::VectorXd beta(4); -// beta << 3.0, 3.0, 2.0, 2.1; -// bayesmix::Vector beta_proto; -// bayesmix::to_proto(beta, &beta_proto); -// *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; -// prior.mutable_fixed_values()->set_phi(phi); -// prior.mutable_fixed_values()->set_alpha0(alpha0); -// prior.mutable_fixed_values()->set_q(q); -// *prior.mutable_fixed_values()->mutable_beta() = beta_proto; -// hier->get_mutable_prior()->CopyFrom(prior); -// Eigen::MatrixXd dataset(5, 4); -// dataset << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, -// 19, -// 20; -// hier->set_dataset(&dataset); -// hier->initialize(); +TEST(fa_hierarchy, sample_given_data) { + auto hier = std::make_shared(); + bayesmix::FAPrior prior; + Eigen::VectorXd mutilde(4); + mutilde << 3.0, 3.0, 4.0, 1.0; + bayesmix::Vector mutilde_proto; + bayesmix::to_proto(mutilde, &mutilde_proto); + int q = 2; + double phi = 1.0; + double alpha0 = 5.0; + Eigen::VectorXd beta(4); + beta << 3.0, 3.0, 2.0, 2.1; + bayesmix::Vector beta_proto; + bayesmix::to_proto(beta, &beta_proto); + *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; + prior.mutable_fixed_values()->set_phi(phi); + prior.mutable_fixed_values()->set_alpha0(alpha0); + prior.mutable_fixed_values()->set_q(q); + *prior.mutable_fixed_values()->mutable_beta() = beta_proto; + hier->get_mutable_prior()->CopyFrom(prior); + Eigen::MatrixXd dataset(5, 4); + dataset << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20; + hier->set_dataset(&dataset); + hier->initialize(); -// auto hier2 = hier->clone(); -// hier2->add_datum(0, dataset.row(0), false); -// hier2->add_datum(1, dataset.row(1), false); -// hier2->sample_full_cond(); -// bayesmix::AlgorithmState out; -// bayesmix::AlgorithmState::ClusterState* clusval = -// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 -// = out.add_cluster_states(); hier->write_state_to_proto(clusval); -// hier2->write_state_to_proto(clusval2); + auto hier2 = hier->clone(); + hier2->add_datum(0, dataset.row(0), false); + hier2->add_datum(1, dataset.row(1), false); + hier2->sample_full_cond(); + bayesmix::AlgorithmState out; + bayesmix::AlgorithmState::ClusterState* clusval = out.add_cluster_states(); + bayesmix::AlgorithmState::ClusterState* clusval2 = out.add_cluster_states(); + hier->write_state_to_proto(clusval); + hier2->write_state_to_proto(clusval2); -// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); -// } + ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +} From 0214c124161261f360f056a52da8b7f3c9520f73 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:45:47 +0100 Subject: [PATCH 207/317] Code fixes --- src/hierarchies/priors/fa_prior_model.cc | 42 ++++++++++++------------ 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc index ccc541f7a..43abee380 100644 --- a/src/hierarchies/priors/fa_prior_model.cc +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -3,10 +3,10 @@ double FAPriorModel::lpdf(const google::protobuf::Message &state_) { // Downcast state auto &state = downcast_state(state_).fa_state(); - // Proto to Eigen conversion + // Proto2Eigen conversion Eigen::VectorXd mu = bayesmix::to_eigen(state.mu()); Eigen::VectorXd psi = bayesmix::to_eigen(state.psi()); - Eigen::MatrixXd eta = bayesmix::to_eigen(state.eta()); + // Eigen::MatrixXd eta = bayesmix::to_eigen(state.eta()); Eigen::MatrixXd lambda = bayesmix::to_eigen(state.lambda()); // Initialize lpdf value double target = 0.; @@ -20,11 +20,11 @@ double FAPriorModel::lpdf(const google::protobuf::Message &state_) { target += stan::math::normal_lpdf(lambda(j, i), 0, 1); } } - for (size_t i = 0; i < eta.rows(); i++) { - for (size_t j = 0; j < hypers->q; j++) { - target += stan::math::normal_lpdf(eta(i, j), 0, 1); - } - } + // for (size_t i = 0; i < eta.rows(); i++) { + // for (size_t j = 0; j < hypers->q; j++) { + // target += stan::math::normal_lpdf(eta(i, j), 0, 1); + // } + // } // Return lpdf contribution return target; } @@ -42,7 +42,7 @@ std::shared_ptr FAPriorModel::sample( State::FA out; out.mu = params.mutilde; out.psi = params.beta / (params.alpha0 + 1.); - out.eta = Eigen::MatrixXd::Zero(params.card, params.q); + // out.eta = Eigen::MatrixXd::Zero(params.card, params.q); out.lambda = Eigen::MatrixXd::Zero(dim, params.q); for (size_t j = 0; j < dim; j++) { out.mu[j] = @@ -54,22 +54,22 @@ std::shared_ptr FAPriorModel::sample( out.lambda(j, i) = stan::math::normal_rng(0, 1, rng); } } - for (size_t i = 0; i < params.card; i++) { - for (size_t j = 0; j < params.q; j++) { - out.eta(i, j) = stan::math::normal_rng(0, 1, rng); - } - } + // for (size_t i = 0; i < params.card; i++) { + // for (size_t j = 0; j < params.q; j++) { + // out.eta(i, j) = stan::math::normal_rng(0, 1, rng); + // } + // } // Questi conti non li passo al proto, attenzione !!! // out.psi_inverse = out.psi.cwiseInverse().asDiagonal(); // compute_wood_factors(out.cov_wood, out.cov_logdet, out.lambda, // out.psi_inverse); - // Convert to proto + // Eigen2Proto conversion bayesmix::AlgorithmState::ClusterState state; bayesmix::to_proto(out.mu, state.mutable_fa_state()->mutable_mu()); bayesmix::to_proto(out.psi, state.mutable_fa_state()->mutable_psi()); - bayesmix::to_proto(out.eta, state.mutable_fa_state()->mutable_eta()); + // bayesmix::to_proto(out.eta, state.mutable_fa_state()->mutable_eta()); bayesmix::to_proto(out.lambda, state.mutable_fa_state()->mutable_lambda()); return std::make_shared(state); @@ -95,7 +95,7 @@ void FAPriorModel::set_hypers_from_proto( hypers->beta = bayesmix::to_eigen(hyperscast.beta()); hypers->phi = hyperscast.phi(); hypers->q = hyperscast.q(); - hypers->card = hyperscast.card(); + // hypers->card = hyperscast.card(); } std::shared_ptr @@ -106,7 +106,7 @@ FAPriorModel::get_hypers_proto() const { hypers_.set_alpha0(hypers->alpha0); hypers_.set_phi(hypers->phi); hypers_.set_q(hypers->q); - hypers_.set_card(hypers->card); + // hypers_.set_card(hypers->card); auto out = std::make_shared(); out->mutable_fa_state()->CopyFrom(hypers_); @@ -122,7 +122,7 @@ void FAPriorModel::initialize_hypers() { hypers->phi = prior->fixed_values().phi(); hypers->alpha0 = prior->fixed_values().alpha0(); hypers->q = prior->fixed_values().q(); - hypers->card = prior->fixed_values().card(); + // hypers->card = prior->fixed_values().card(); // Check validity if (dim != hypers->beta.rows()) { @@ -143,9 +143,9 @@ void FAPriorModel::initialize_hypers() { if (hypers->q <= 0) { throw std::invalid_argument("Number of factors must be > 0"); } - if (hypers->card <= 0) { - throw std::invalid_argument("Number of data must be > 0"); - } + // if (hypers->card < 0) { + // throw std::invalid_argument("Number of data must be >= 0"); + // } } else { From 8a38363ce120eff93d337bdd9fcd87adb6232bf7 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:45:59 +0100 Subject: [PATCH 208/317] Code fixes --- src/hierarchies/priors/base_prior_model.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 8c43d39fb..cb3f71705 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -74,13 +74,13 @@ class BasePriorModel : public AbstractPriorModel { } //! Returns an independent, data-less copy of this object - virtual std::shared_ptr clone() const override; + std::shared_ptr clone() const override; //! Returns an independent, data-less deep copy of this object - virtual std::shared_ptr deep_clone() const override; + std::shared_ptr deep_clone() const override; //! Returns a pointer to the Protobuf message of the prior of this cluster - virtual google::protobuf::Message *get_mutable_prior() override; + google::protobuf::Message *get_mutable_prior() override; //! Returns the struct of the current prior hyperparameters HyperParams get_hypers() const { return *hypers; } From 8980003d78315a443dea68ee7445d7ae62b79cbb Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:46:16 +0100 Subject: [PATCH 209/317] Code fixes --- src/hierarchies/likelihoods/base_likelihood.h | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 0ebff05fe..0a471a5d6 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -70,7 +70,7 @@ class BaseLikelihood : public AbstractLikelihood { ~BaseLikelihood() = default; //! Returns an independent, data-less copy of this object - virtual std::shared_ptr clone() const override { + std::shared_ptr clone() const override { auto out = std::make_shared(static_cast(*this)); out->clear_data(); out->clear_summary_statistics(); @@ -113,9 +113,9 @@ class BaseLikelihood : public AbstractLikelihood { //! @param data Grid of points (by row) which are to be evaluated //! @param covariates (Optional) covariate vectors associated to data //! @return The evaluation of the lpdf - virtual Eigen::VectorXd lpdf_grid(const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = - Eigen::MatrixXd(0, 0)) const override; + Eigen::VectorXd lpdf_grid(const Eigen::MatrixXd &data, + const Eigen::MatrixXd &covariates = + Eigen::MatrixXd(0, 0)) const override; //! Returns the current cardinality of the cluster int get_card() const { return card; } @@ -149,10 +149,13 @@ class BaseLikelihood : public AbstractLikelihood { } //! Sets the (pointer to) the dataset in the cluster - void set_dataset(const Eigen::MatrixXd *const dataset) { + void set_dataset(const Eigen::MatrixXd *const dataset) override { dataset_ptr = dataset; } + //! Returns the (pointer to) the dataset in the cluster + const Eigen::MatrixXd *get_dataset() const { return dataset_ptr; } + //! Adds a datum and its index to the likelihood void add_datum( const int id, const Eigen::RowVectorXd &datum, From d3780ce250818f7ff8aa1e8dd9cd62bcf0fe0dec Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:46:52 +0100 Subject: [PATCH 210/317] made compute_wood_factors() public method --- src/hierarchies/likelihoods/fa_likelihood.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/hierarchies/likelihoods/fa_likelihood.h b/src/hierarchies/likelihoods/fa_likelihood.h index d9cd5a77e..83a9c2753 100644 --- a/src/hierarchies/likelihoods/fa_likelihood.h +++ b/src/hierarchies/likelihoods/fa_likelihood.h @@ -31,14 +31,15 @@ class FALikelihood : public BaseLikelihood { std::shared_ptr get_state_proto() const override; - protected: - double compute_lpdf(const Eigen::RowVectorXd& datum) const override; - void update_sum_stats(const Eigen::RowVectorXd& datum, bool add) override; void compute_wood_factors( Eigen::MatrixXd& cov_wood, double& cov_logdet, const Eigen::MatrixXd& lambda, const Eigen::DiagonalMatrix& psi_inverse); + protected: + double compute_lpdf(const Eigen::RowVectorXd& datum) const override; + void update_sum_stats(const Eigen::RowVectorXd& datum, bool add) override; + unsigned int dim; Eigen::VectorXd data_sum; }; From d13a2ccbc3653ffdd0f1a668320b6e76c7e83d67 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:47:02 +0100 Subject: [PATCH 211/317] Add FAHierarchy --- src/hierarchies/load_hierarchies.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/load_hierarchies.h b/src/hierarchies/load_hierarchies.h index eb87cac63..21ec46ae5 100644 --- a/src/hierarchies/load_hierarchies.h +++ b/src/hierarchies/load_hierarchies.h @@ -5,7 +5,7 @@ #include #include "abstract_hierarchy.h" -// #include "fa_hierarchy.h" +#include "fa_hierarchy.h" #include "hierarchy_id.pb.h" #include "lapnig_hierarchy.h" #include "lin_reg_uni_hierarchy.h" @@ -37,9 +37,9 @@ __attribute__((constructor)) static void load_hierarchies() { Builder LinRegUnibuilder = []() { return std::make_shared(); }; - // Builder FAbuilder = []() { - // return std::make_shared(); - // }; + Builder FAbuilder = []() { + return std::make_shared(); + }; Builder LapNIGbuilder = []() { return std::make_shared(); }; @@ -48,7 +48,7 @@ __attribute__((constructor)) static void load_hierarchies() { factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); factory.add_builder(NNWHierarchy().get_id(), NNWbuilder); factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); - // factory.add_builder(FAHierarchy().get_id(), FAbuilder); + factory.add_builder(FAHierarchy().get_id(), FAbuilder); factory.add_builder(LapNIGHierarchy().get_id(), LapNIGbuilder); } From 1b8f809f3967e8a51df353d4f17c49b871bb8dbc Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:47:20 +0100 Subject: [PATCH 212/317] Add re-factored FAHierarchy --- src/hierarchies/fa_hierarchy.h | 52 ++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 src/hierarchies/fa_hierarchy.h diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h new file mode 100644 index 000000000..f68ba3fdd --- /dev/null +++ b/src/hierarchies/fa_hierarchy.h @@ -0,0 +1,52 @@ +#ifndef BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ + +// #include + +// #include +// #include +// #include + +// #include "algorithm_state.pb.h" +// #include "conjugate_hierarchy.h" +#include "hierarchy_id.pb.h" +#include "src/utils/distributions.h" +// #include "hierarchy_prior.pb.h" + +#include "base_hierarchy.h" +#include "likelihoods/fa_likelihood.h" +#include "priors/fa_prior_model.h" +#include "updaters/fa_updater.h" + +class FAHierarchy + : public BaseHierarchy { + public: + FAHierarchy() = default; + ~FAHierarchy() = default; + + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::FA; + } + + void set_default_updater() { updater = std::make_shared(); } + + void initialize_state() override { + // Initialize likelihood dimension to prior one + like->set_dim(prior->get_dim()); + // Get hypers and data dimension + auto hypers = prior->get_hypers(); + unsigned int dim = like->get_dim(); + // Initialize likelihood state + State::FA state; + state.mu = hypers.mutilde; + state.psi = hypers.beta / (hypers.alpha0 + 1.); + state.eta = Eigen::MatrixXd::Zero(hypers.card, hypers.q); + state.lambda = Eigen::MatrixXd::Zero(dim, hypers.q); + state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); + like->set_state(state); + like->compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda, + state.psi_inverse); + } +}; + +#endif // BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ From e283e654f7092bb760990654be9dff6b403bb502 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:52:37 +0100 Subject: [PATCH 213/317] Delete src/hierarchies/.old directory --- src/hierarchies/.old/base_hierarchy.h | 321 --------------- src/hierarchies/.old/conjugate_hierarchy.h | 206 ---------- src/hierarchies/.old/lapnig_hierarchy.cc | 215 ---------- src/hierarchies/.old/lapnig_hierarchy.h | 146 ------- src/hierarchies/.old/lin_reg_uni_hierarchy.cc | 166 -------- src/hierarchies/.old/lin_reg_uni_hierarchy.h | 144 ------- src/hierarchies/.old/nnig_hierarchy.cc | 267 ------------- src/hierarchies/.old/nnig_hierarchy.h | 122 ------ src/hierarchies/.old/nnw_hierarchy.cc | 373 ------------------ src/hierarchies/.old/nnw_hierarchy.h | 168 -------- src/hierarchies/.old/nnxig_hierarchy.cc | 152 ------- src/hierarchies/.old/nnxig_hierarchy.h | 120 ------ 12 files changed, 2400 deletions(-) delete mode 100644 src/hierarchies/.old/base_hierarchy.h delete mode 100644 src/hierarchies/.old/conjugate_hierarchy.h delete mode 100644 src/hierarchies/.old/lapnig_hierarchy.cc delete mode 100644 src/hierarchies/.old/lapnig_hierarchy.h delete mode 100644 src/hierarchies/.old/lin_reg_uni_hierarchy.cc delete mode 100644 src/hierarchies/.old/lin_reg_uni_hierarchy.h delete mode 100644 src/hierarchies/.old/nnig_hierarchy.cc delete mode 100644 src/hierarchies/.old/nnig_hierarchy.h delete mode 100644 src/hierarchies/.old/nnw_hierarchy.cc delete mode 100644 src/hierarchies/.old/nnw_hierarchy.h delete mode 100644 src/hierarchies/.old/nnxig_hierarchy.cc delete mode 100644 src/hierarchies/.old/nnxig_hierarchy.h diff --git a/src/hierarchies/.old/base_hierarchy.h b/src/hierarchies/.old/base_hierarchy.h deleted file mode 100644 index e97c92080..000000000 --- a/src/hierarchies/.old/base_hierarchy.h +++ /dev/null @@ -1,321 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ - -#include - -#include -#include -#include -#include -#include - -#include "abstract_hierarchy.h" -#include "algorithm_state.pb.h" -#include "hierarchy_id.pb.h" -#include "src/utils/rng.h" - -//! Base template class for a hierarchy object. - -//! This class is a templatized version of, and derived from, the -//! `AbstractHierarchy` class, and the second stage of the curiously recurring -//! template pattern for `Hierarchy` objects (please see the docs of the parent -//! class for further information). It includes class members and some more -//! functions which could not be implemented in the non-templatized abstract -//! class. -//! See, for instance, `ConjugateHierarchy` and `NNIGHierarchy` to better -//! understand the CRTP patterns. - -//! @tparam Derived Name of the implemented derived class -//! @tparam State Class name of the container for state values -//! @tparam Hyperparams Class name of the container for hyperprior parameters -//! @tparam Prior Class name of the container for prior parameters - -template -class BaseHierarchy : public AbstractHierarchy { - public: - BaseHierarchy() = default; - ~BaseHierarchy() = default; - - //! Returns an independent, data-less copy of this object - virtual std::shared_ptr clone() const override { - auto out = std::make_shared(static_cast(*this)); - out->clear_data(); - out->clear_summary_statistics(); - return out; - } - - //! Evaluates the log-likelihood of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - virtual Eigen::VectorXd like_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Generates new state values from the centering prior distribution - void sample_prior() override { - state = static_cast(this)->draw(*hypers); - } - - //! Overloaded version of sample_full_cond(bool), mainly used for debugging - virtual void sample_full_cond( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override; - - //! Returns the current cardinality of the cluster - int get_card() const override { return card; } - - //! Returns the logarithm of the current cardinality of the cluster - double get_log_card() const override { return log_card; } - - //! Returns the indexes of data points belonging to this cluster - std::set get_data_idx() const override { return cluster_data_idx; } - - //! Returns a pointer to the Protobuf message of the prior of this cluster - virtual google::protobuf::Message *get_mutable_prior() override { - if (prior == nullptr) { - create_empty_prior(); - } - return prior.get(); - } - - //! Writes current state to a Protobuf message by pointer - void write_state_to_proto(google::protobuf::Message *out) const override; - - //! Writes current values of the hyperparameters to a Protobuf message by - //! pointer - void write_hypers_to_proto(google::protobuf::Message *out) const override; - - //! Returns the struct of the current state - State get_state() const { return state; } - - //! Returns the struct of the current prior hyperparameters - Hyperparams get_hypers() const { return *hypers; } - - //! Returns the struct of the current posterior hyperparameters - Hyperparams get_posterior_hypers() const { return posterior_hypers; } - - //! Adds a datum and its index to the hierarchy - void add_datum( - const int id, const Eigen::RowVectorXd &datum, - const bool update_params = false, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; - - //! Removes a datum and its index from the hierarchy - void remove_datum( - const int id, const Eigen::RowVectorXd &datum, - const bool update_params = false, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override; - - //! Main function that initializes members to appropriate values - void initialize() override { - hypers = std::make_shared(); - check_prior_is_set(); - initialize_hypers(); - initialize_state(); - posterior_hypers = *hypers; - clear_data(); - clear_summary_statistics(); - } - - protected: - //! Raises an error if the prior pointer is not initialized - void check_prior_is_set() const { - if (prior == nullptr) { - throw std::invalid_argument("Hierarchy prior was not provided"); - } - } - - //! Re-initializes the prior of the hierarchy to a newly created object - void create_empty_prior() { prior.reset(new Prior); } - - //! Sets the cardinality of the cluster - void set_card(const int card_) { - card = card_; - log_card = (card_ == 0) ? stan::math::NEGATIVE_INFTY : std::log(card_); - } - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - virtual std::shared_ptr - get_state_proto() const = 0; - - //! Initializes state parameters to appropriate values - virtual void initialize_state() = 0; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - virtual std::shared_ptr - get_hypers_proto() const = 0; - - //! Initializes hierarchy hyperparameters to appropriate values - virtual void initialize_hypers() = 0; - - //! Resets cardinality and indexes of data in this cluster - void clear_data() { - set_card(0); - cluster_data_idx = std::set(); - } - - virtual void clear_summary_statistics() = 0; - - //! Down-casts the given generic proto message to a ClusterState proto - bayesmix::AlgorithmState::ClusterState *downcast_state( - google::protobuf::Message *state_) const { - return google::protobuf::internal::down_cast< - bayesmix::AlgorithmState::ClusterState *>(state_); - } - - //! Down-casts the given generic proto message to a ClusterState proto - const bayesmix::AlgorithmState::ClusterState &downcast_state( - const google::protobuf::Message &state_) const { - return google::protobuf::internal::down_cast< - const bayesmix::AlgorithmState::ClusterState &>(state_); - } - - //! Down-casts the given generic proto message to a HierarchyHypers proto - bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( - google::protobuf::Message *state_) const { - return google::protobuf::internal::down_cast< - bayesmix::AlgorithmState::HierarchyHypers *>(state_); - } - - //! Down-casts the given generic proto message to a HierarchyHypers proto - const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( - const google::protobuf::Message &state_) const { - return google::protobuf::internal::down_cast< - const bayesmix::AlgorithmState::HierarchyHypers &>(state_); - } - - //! Container for state values - State state; - - //! Container for prior hyperparameters values - std::shared_ptr hypers; - - //! Container for posterior hyperparameters values - Hyperparams posterior_hypers; - - //! Pointer to a Protobuf prior object for this class - std::shared_ptr prior; - - //! Set of indexes of data points belonging to this cluster - std::set cluster_data_idx; - - //! Current cardinality of this cluster - int card = 0; - - //! Logarithm of current cardinality of this cluster - double log_card = stan::math::NEGATIVE_INFTY; -}; - -template -void BaseHierarchy::add_datum( - const int id, const Eigen::RowVectorXd &datum, - const bool update_params /*= false*/, - const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) { - assert(cluster_data_idx.find(id) == cluster_data_idx.end()); - card += 1; - log_card = std::log(card); - static_cast(this)->update_ss(datum, covariate, true); - cluster_data_idx.insert(id); - if (update_params) { - static_cast(this)->save_posterior_hypers(); - } -} - -template -void BaseHierarchy::remove_datum( - const int id, const Eigen::RowVectorXd &datum, - const bool update_params /*= false*/, - const Eigen::RowVectorXd &covariate /* = Eigen::RowVectorXd(0)*/) { - static_cast(this)->update_ss(datum, covariate, false); - set_card(card - 1); - auto it = cluster_data_idx.find(id); - assert(it != cluster_data_idx.end()); - cluster_data_idx.erase(it); - if (update_params) { - static_cast(this)->save_posterior_hypers(); - } -} - -template -void BaseHierarchy::write_state_to_proto( - google::protobuf::Message *out) const { - std::shared_ptr state_ = - get_state_proto(); - auto *out_cast = downcast_state(out); - out_cast->CopyFrom(*state_.get()); - out_cast->set_cardinality(card); -} - -template -void BaseHierarchy::write_hypers_to_proto( - google::protobuf::Message *out) const { - std::shared_ptr hypers_ = - get_hypers_proto(); - auto *out_cast = downcast_hypers(out); - out_cast->CopyFrom(*hypers_.get()); -} - -template -Eigen::VectorXd -BaseHierarchy::like_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { - Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->get_like_lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->get_like_lpdf( - data.row(i), covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->get_like_lpdf( - data.row(i), covariates.row(i)); - } - } - return lpdf; -} - -template -void BaseHierarchy::sample_full_cond( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) { - clear_data(); - clear_summary_statistics(); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - covariates.row(i)); - } - } - static_cast(this)->sample_full_cond(true); -} - -#endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ diff --git a/src/hierarchies/.old/conjugate_hierarchy.h b/src/hierarchies/.old/conjugate_hierarchy.h deleted file mode 100644 index 4d2430bea..000000000 --- a/src/hierarchies/.old/conjugate_hierarchy.h +++ /dev/null @@ -1,206 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_ - -#include "base_hierarchy.h" - -//! Template base class for conjugate hierarchy objects. - -//! This class acts as the base class for conjugate models, i.e. ones for which -//! both the prior and posterior distribution have the same form -//! (non-conjugate hierarchies should instead inherit directly from -//! `BaseHierarchy`). This also means that the marginal distribution for the -//! data is available in closed form. For this reason, each class deriving from -//! this one must have a free method with one of the following signatures, -//! based on whether it depends on covariates or not: -//! double marg_lpdf( -//! const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, -//! const Eigen::RowVectorXd &covariate) const; -//! or -//! double marg_lpdf( -//! const Hyperparams ¶ms, const Eigen::RowVectorXd &datum) const; -//! This returns the evaluation of the marginal distribution on the given data -//! point (and covariate, if any), conditioned on the provided `Hyperparams` -//! object. The latter may contain either prior or posterior values for -//! hyperparameters, depending on where this function is called within the -//! library. -//! For more information, please refer to parent classes `AbstractHierarchy` -//! and `BaseHierarchy`. - -template -class ConjugateHierarchy - : public BaseHierarchy { - public: - using BaseHierarchy::hypers; - using BaseHierarchy::posterior_hypers; - using BaseHierarchy::state; - - ConjugateHierarchy() = default; - ~ConjugateHierarchy() = default; - - //! Public wrapper for `marg_lpdf()` methods - virtual double get_marg_lpdf( - const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const; - - //! Evaluates the log-prior predictive distribution of data in a single point - //! @param datum Point which is to be evaluated - //! @param covariate (Optional) covariate vector associated to datum - //! @return The evaluation of the lpdf - double prior_pred_lpdf(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = - Eigen::RowVectorXd(0)) const override { - return get_marg_lpdf(*hypers, datum, covariate); - } - - //! Evaluates the log-conditional predictive distr. of data in a single point - //! @param datum Point which is to be evaluated - //! @param covariate (Optional) covariate vector associated to datum - //! @return The evaluation of the lpdf - double conditional_pred_lpdf(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = - Eigen::RowVectorXd(0)) const override { - return get_marg_lpdf(posterior_hypers, datum, covariate); - } - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - virtual Eigen::VectorXd prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - virtual Eigen::VectorXd conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Generates new state values from the centering posterior distribution - //! @param update_params Save posterior hypers after the computation? - void sample_full_cond(const bool update_params = true) override { - if (this->card == 0) { - // No posterior update possible - static_cast(this)->sample_prior(); - } else { - Hyperparams params = - update_params - ? static_cast(this)->compute_posterior_hypers() - : posterior_hypers; - state = static_cast(this)->draw(params); - } - } - - //! Saves posterior hyperparameters to the corresponding class member - void save_posterior_hypers() { - posterior_hypers = - static_cast(this)->compute_posterior_hypers(); - } - - //! Returns whether the hierarchy represents a conjugate model or not - bool is_conjugate() const override { return true; } - - protected: - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! @return The evaluation of the lpdf - virtual double marg_lpdf(const Hyperparams ¶ms, - const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const { - if (!this->is_dependent()) { - throw std::runtime_error( - "Cannot call marg_lpdf() from a non-dependent hierarchy"); - } else { - throw std::runtime_error("marg_lpdf() not implemented"); - } - } - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - virtual double marg_lpdf(const Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const { - if (this->is_dependent()) { - throw std::runtime_error( - "Cannot call marg_lpdf() from a dependent hierarchy"); - } else { - throw std::runtime_error("marg_lpdf() not implemented"); - } - } -}; - -template -double ConjugateHierarchy::get_marg_lpdf( - const Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const { - if (this->is_dependent()) { - return marg_lpdf(params, datum, covariate); - } else { - return marg_lpdf(params, datum); - } -} - -template -Eigen::VectorXd -ConjugateHierarchy::prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { - Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->prior_pred_lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->prior_pred_lpdf( - data.row(i), covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->prior_pred_lpdf( - data.row(i), covariates.row(i)); - } - } - return lpdf; -} - -template -Eigen::VectorXd ConjugateHierarchy:: - conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const { - Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), covariates.row(i)); - } - } - return lpdf; -} - -#endif // BAYESMIX_HIERARCHIES_CONJUGATE_HIERARCHY_H_ diff --git a/src/hierarchies/.old/lapnig_hierarchy.cc b/src/hierarchies/.old/lapnig_hierarchy.cc deleted file mode 100644 index b0d479244..000000000 --- a/src/hierarchies/.old/lapnig_hierarchy.cc +++ /dev/null @@ -1,215 +0,0 @@ -#include "lapnig_hierarchy.h" - -#include - -#include -#include -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "src/utils/rng.h" - -unsigned int LapNIGHierarchy::accepted_ = 0; -unsigned int LapNIGHierarchy::iter_ = 0; - -void LapNIGHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.mean = statecast.uni_ls_state().mean(); - state.scale = statecast.uni_ls_state().var(); - set_card(statecast.cardinality()); -} - -void LapNIGHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).lapnig_state(); - hypers->mean = hyperscast.mean(); - hypers->var = hyperscast.var(); - hypers->scale = hyperscast.scale(); - hypers->shape = hyperscast.shape(); - hypers->mh_mean_var = hyperscast.mh_mean_var(); - hypers->mh_log_scale_var = hyperscast.mh_log_scale_var(); -} - -double LapNIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { - return stan::math::double_exponential_lpdf(datum(0), state.mean, - state.scale); -} - -std::shared_ptr -LapNIGHierarchy::get_state_proto() const { - bayesmix::UniLSState state_; - state_.set_mean(state.mean); - state_.set_var(state.scale); - - auto out = std::make_shared(); - out->mutable_uni_ls_state()->CopyFrom(state_); - return out; -} - -std::shared_ptr -LapNIGHierarchy::get_hypers_proto() const { - bayesmix::LapNIGState hypers_; - hypers_.set_mean(hypers->mean); - hypers_.set_var(hypers->var); - hypers_.set_shape(hypers->shape); - hypers_.set_scale(hypers->scale); - hypers_.set_mh_mean_var(hypers->mh_mean_var); - hypers_.set_mh_log_scale_var(hypers->mh_log_scale_var); - - auto out = std::make_shared(); - out->mutable_lapnig_state()->CopyFrom(hypers_); - return out; -} - -void LapNIGHierarchy::clear_summary_statistics() { - cluster_data_values.clear(); - sum_abs_diff_curr = 0; - sum_abs_diff_prop = 0; -} - -void LapNIGHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = prior->fixed_values().mean(); - hypers->var = prior->fixed_values().var(); - hypers->shape = prior->fixed_values().shape(); - hypers->scale = prior->fixed_values().scale(); - hypers->mh_mean_var = prior->fixed_values().mh_mean_var(); - hypers->mh_log_scale_var = prior->fixed_values().mh_log_scale_var(); - // Check validity - if (hypers->var <= 0) { - throw std::invalid_argument("Variance-scaling parameter must be > 0"); - } - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("Scale parameter must be > 0"); - } - if (hypers->mh_mean_var <= 0) { - throw std::invalid_argument("mh_mean_var parameter must be > 0"); - } - if (hypers->mh_log_scale_var <= 0) { - throw std::invalid_argument("mh_log_scale_var parameter must be > 0"); - } - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void LapNIGHierarchy::initialize_state() { - state.mean = hypers->mean; - state.scale = hypers->scale / (hypers->shape + 1); // mode of Inv-Gamma -} - -void LapNIGHierarchy::update_hypers( - const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -LapNIG::State LapNIGHierarchy::draw(const LapNIG::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - LapNIG::State out{}; - out.scale = stan::math::inv_gamma_rng(params.shape, 1. / params.scale, rng); - out.mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng); - return out; -} - -void LapNIGHierarchy::update_summary_statistics( - const Eigen::RowVectorXd &datum, const bool add) { - if (add) { - sum_abs_diff_curr += std::abs(state.mean - datum(0, 0)); - cluster_data_values.push_back(datum); - } else { - sum_abs_diff_curr -= std::abs(state.mean - datum(0, 0)); - auto it = std::find(cluster_data_values.begin(), cluster_data_values.end(), - datum); - cluster_data_values.erase(it); - } -} - -void LapNIGHierarchy::sample_full_cond(const bool update_params /*= false*/) { - if (this->card == 0) { - // No posterior update possible - this->sample_prior(); - } else { - // Number of iterations to compute the acceptance rate - ++iter_; - - // Random generator - auto &rng = bayesmix::Rng::Instance().get(); - - // Candidate mean and candidate log_scale - Eigen::VectorXd curr_unc_params(2); - curr_unc_params << state.mean, std::log(state.scale); - - Eigen::VectorXd prop_unc_params = propose_rwmh(curr_unc_params); - - double log_target_prop = - eval_prior_lpdf_unconstrained(prop_unc_params) + - eval_like_lpdf_unconstrained(prop_unc_params, false); - - double log_target_curr = - eval_prior_lpdf_unconstrained(curr_unc_params) + - eval_like_lpdf_unconstrained(curr_unc_params, true); - - double log_a_rate = log_target_prop - log_target_curr; - - if (std::log(stan::math::uniform_rng(0, 1, rng)) < log_a_rate) { - ++accepted_; - state.mean = prop_unc_params(0); - state.scale = std::exp(prop_unc_params(1)); - sum_abs_diff_curr = sum_abs_diff_prop; - } - } -} - -Eigen::VectorXd LapNIGHierarchy::propose_rwmh( - const Eigen::VectorXd &curr_vals) { - auto &rng = bayesmix::Rng::Instance().get(); - double candidate_mean = - curr_vals(0) + stan::math::normal_rng(0, sqrt(hypers->mh_mean_var), rng); - double candidate_log_scale = - curr_vals(1) + - stan::math::normal_rng(0, sqrt(hypers->mh_log_scale_var), rng); - Eigen::VectorXd proposal(2); - proposal << candidate_mean, candidate_log_scale; - return proposal; -} - -double LapNIGHierarchy::eval_prior_lpdf_unconstrained( - const Eigen::VectorXd &unconstrained_parameters) { - double mu = unconstrained_parameters(0); - double log_scale = unconstrained_parameters(1); - double scale = std::exp(log_scale); - return stan::math::normal_lpdf(mu, hypers->mean, std::sqrt(hypers->var)) + - stan::math::inv_gamma_lpdf(scale, hypers->shape, hypers->scale) + - log_scale; -} - -double LapNIGHierarchy::eval_like_lpdf_unconstrained( - const Eigen::VectorXd &unconstrained_parameters, const bool is_current) { - double mean = unconstrained_parameters(0); - double log_scale = unconstrained_parameters(1); - double scale = std::exp(log_scale); - double diff_sum = 0; // Sum of absolute values of data - candidate_mean - if (is_current) { - diff_sum = sum_abs_diff_curr; - } else { - for (auto &elem : cluster_data_values) { - diff_sum += std::abs(elem(0, 0) - mean); - } - sum_abs_diff_prop = diff_sum; - } - return std::log(0.5 / scale) + (-0.5 / scale * diff_sum); -} diff --git a/src/hierarchies/.old/lapnig_hierarchy.h b/src/hierarchies/.old/lapnig_hierarchy.h deleted file mode 100644 index cdb07e55b..000000000 --- a/src/hierarchies/.old/lapnig_hierarchy.h +++ /dev/null @@ -1,146 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ - -#include - -#include -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "base_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Laplace Normal-InverseGamma hierarchy for univariate data. - -//! This class represents a hierarchical model where data are distributed -//! according to a laplace likelihood, the parameters of which have a -//! Normal-InverseGamma centering distribution. That is: -//! f(x_i|mu,lambda) = Laplace(mu,lambda) -//! (mu,lambda) ~ N-IG(mu0, lambda0, alpha0, beta0) -//! The state is composed of mean and scale. The state hyperparameters, -//! contained in the Hypers object, are (mu_0, lambda0, alpha0, beta0, -//! scale_var, mean_var), all scalar values. Note that this hierarchy is NOT -//! conjugate, thus the marginal distribution is not available in closed form. -//! The hyperprameters scale_var and mean_var are used to perform a step of -//! Random Walk Metropolis Hastings to sample from the full conditionals. For -//! more information, please refer to parent classes: `AbstractHierarchy` and -//! `BaseHierarchy`. - -namespace LapNIG { -//! Custom container for State values -struct State { - double mean, scale; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - double mean, var, shape, scale, mh_log_scale_var, mh_mean_var; -}; -} // namespace LapNIG - -class LapNIGHierarchy - : public BaseHierarchy { - public: - LapNIGHierarchy() = default; - ~LapNIGHierarchy() = default; - - //! Counters for tracking acceptance rate in MH step - static unsigned int accepted_; - static unsigned int iter_; - - //! Returns the Protobuf ID associated to this class - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::LapNIG; - } - - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - void save_posterior_hypers() { - throw std::runtime_error("save_posterior_hypers() not implemented"); - } - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return false; } - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - LapNIG::State draw(const LapNIG::Hyperparams ¶ms); - - //! Generates new state values from the centering posterior distribution - //! @param update_params Save posterior hypers after the computation? - void sample_full_cond(const bool update_params = false) override; - - protected: - //! Set of values of data points belonging to this cluster - std::list cluster_data_values; - - //! Sum of absolute differences for current params - double sum_abs_diff_curr = 0; - - //! Sum of absolute differences for proposal params - double sum_abs_diff_prop = 0; - - //! Samples from the proposal distribution using Random Walk - //! Metropolis-Hastings - Eigen::VectorXd propose_rwmh(const Eigen::VectorXd &curr_vals); - - //! Evaluates the prior given the mean (unconstrained_parameters(0)) - //! and log of the scale (unconstrained_parameters(1)) - double eval_prior_lpdf_unconstrained( - const Eigen::VectorXd &unconstrained_parameters); - - //! Evaluates the (sum of the) log likelihood for all the observations in the - //! cluster given the mean (unconstrained_parameters(0)) - //! and log of the scale (unconstrained_parameters(1)). - //! The parameter "is_current" is used to identify if the evaluation of the - //! likelihood is on the current or on the proposed parameters, in order to - //! avoid repeating calculations of the sum of the absolute differences - double eval_like_lpdf_unconstrained( - const Eigen::VectorXd &unconstrained_parameters, const bool is_current); - - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - const bool add) override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Initializes state parameters to appropriate values - void initialize_state() override; -}; -#endif // BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ diff --git a/src/hierarchies/.old/lin_reg_uni_hierarchy.cc b/src/hierarchies/.old/lin_reg_uni_hierarchy.cc deleted file mode 100644 index 89c469c96..000000000 --- a/src/hierarchies/.old/lin_reg_uni_hierarchy.cc +++ /dev/null @@ -1,166 +0,0 @@ -#include "lin_reg_uni_hierarchy.h" - -#include -#include -#include - -#include "src/utils/eigen_utils.h" -#include "src/utils/proto_utils.h" -#include "src/utils/rng.h" - -double LinRegUniHierarchy::like_lpdf( - const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const { - return stan::math::normal_lpdf( - datum(0), state.regression_coeffs.dot(covariate), sqrt(state.var)); -} - -double LinRegUniHierarchy::marg_lpdf( - const LinRegUni::Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const { - double sig_n = sqrt( - (1 + (covariate * params.var_scaling_inv * covariate.transpose())(0)) * - params.scale / params.shape); - return stan::math::student_t_lpdf(datum(0), 2 * params.shape, - covariate.dot(params.mean), sig_n); -} - -void LinRegUniHierarchy::initialize_state() { - state.regression_coeffs = hypers->mean; - state.var = hypers->scale / (hypers->shape + 1); -} - -void LinRegUniHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); - dim = hypers->mean.size(); - hypers->var_scaling = - bayesmix::to_eigen(prior->fixed_values().var_scaling()); - hypers->var_scaling_inv = stan::math::inverse_spd(hypers->var_scaling); - hypers->shape = prior->fixed_values().shape(); - hypers->scale = prior->fixed_values().scale(); - // Check validity - if (dim != hypers->var_scaling.rows()) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - bayesmix::check_spd(hypers->var_scaling); - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void LinRegUniHierarchy::update_hypers( - const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -LinRegUni::State LinRegUniHierarchy::draw( - const LinRegUni::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - LinRegUni::State out; - out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - out.regression_coeffs = stan::math::multi_normal_prec_rng( - params.mean, params.var_scaling / out.var, rng); - return out; -} - -void LinRegUniHierarchy::update_summary_statistics( - const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate, - const bool add) { - if (add) { - data_sum_squares += datum(0) * datum(0); - covar_sum_squares += covariate.transpose() * covariate; - mixed_prod += datum(0) * covariate.transpose(); - } else { - data_sum_squares -= datum(0) * datum(0); - covar_sum_squares -= covariate.transpose() * covariate; - mixed_prod -= datum(0) * covariate.transpose(); - } -} - -void LinRegUniHierarchy::clear_summary_statistics() { - mixed_prod = Eigen::VectorXd::Zero(dim); - data_sum_squares = 0.0; - covar_sum_squares = Eigen::MatrixXd::Zero(dim, dim); -} - -LinRegUni::Hyperparams LinRegUniHierarchy::compute_posterior_hypers() const { - if (card == 0) { // no update possible - return *hypers; - } - // Compute posterior hyperparameters - LinRegUni::Hyperparams post_params; - post_params.var_scaling = covar_sum_squares + hypers->var_scaling; - auto llt = post_params.var_scaling.llt(); - post_params.var_scaling_inv = llt.solve(Eigen::MatrixXd::Identity(dim, dim)); - post_params.mean = - llt.solve(mixed_prod + hypers->var_scaling * hypers->mean); - post_params.shape = hypers->shape + 0.5 * card; - post_params.scale = - hypers->scale + - 0.5 * (data_sum_squares + - hypers->mean.transpose() * hypers->var_scaling * hypers->mean - - post_params.mean.transpose() * post_params.var_scaling * - post_params.mean); - return post_params; -} - -void LinRegUniHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.regression_coeffs = - bayesmix::to_eigen(statecast.lin_reg_uni_ls_state().regression_coeffs()); - state.var = statecast.lin_reg_uni_ls_state().var(); - set_card(statecast.cardinality()); -} - -std::shared_ptr -LinRegUniHierarchy::get_state_proto() const { - bayesmix::LinRegUniLSState state_; - bayesmix::to_proto(state.regression_coeffs, - state_.mutable_regression_coeffs()); - state_.set_var(state.var); - - auto out = std::make_shared(); - out->mutable_lin_reg_uni_ls_state()->CopyFrom(state_); - return out; -} - -void LinRegUniHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).lin_reg_uni_state(); - hypers->mean = bayesmix::to_eigen(hyperscast.mean()); - hypers->var_scaling = bayesmix::to_eigen(hyperscast.var_scaling()); - hypers->scale = hyperscast.scale(); - hypers->shape = hyperscast.shape(); -} - -std::shared_ptr -LinRegUniHierarchy::get_hypers_proto() const { - bayesmix::MultiNormalIGDistribution hypers_; - bayesmix::to_proto(hypers->mean, hypers_.mutable_mean()); - bayesmix::to_proto(hypers->var_scaling, hypers_.mutable_var_scaling()); - hypers_.set_shape(hypers->shape); - hypers_.set_scale(hypers->scale); - - auto out = std::make_shared(); - out->mutable_lin_reg_uni_state()->CopyFrom(hypers_); - return out; -} diff --git a/src/hierarchies/.old/lin_reg_uni_hierarchy.h b/src/hierarchies/.old/lin_reg_uni_hierarchy.h deleted file mode 100644 index c55a2f7c5..000000000 --- a/src/hierarchies/.old/lin_reg_uni_hierarchy.h +++ /dev/null @@ -1,144 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Linear regression hierarchy for univariate data. - -//! This class implements a dependent hierarchy which represents the classical -//! univariate Bayesian linear regression model, i.e.: -//! y_i | \beta, x_i, \sigma^2 \sim N(\beta^T x_i, sigma^2) -//! \beta | \sigma^2 \sim N(\mu, sigma^2 Lambda^{-1}) -//! \sigma^2 \sim InvGamma(a, b) -//! -//! The state consists of the `regression_coeffs` \beta, and the `var` sigma^2. -//! Lambda is called the variance-scaling factor. For more information, please -//! refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and -//! `ConjugateHierarchy`. - -namespace LinRegUni { -//! Custom container for State values -struct State { - Eigen::VectorXd regression_coeffs; - double var; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - Eigen::VectorXd mean; - Eigen::MatrixXd var_scaling; - Eigen::MatrixXd var_scaling_inv; - double shape; - double scale; -}; -} // namespace LinRegUni - -class LinRegUniHierarchy - : public ConjugateHierarchy { - public: - LinRegUniHierarchy() = default; - ~LinRegUniHierarchy() = default; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - LinRegUni::State draw(const LinRegUni::Hyperparams ¶ms); - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param covariate Covariate vector associated to datum - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate, - const bool add) override; - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Returns the Protobuf ID associated to this class - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::LinRegUni; - } - - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Returns the dimension of the coefficients vector - unsigned int get_dim() const { return dim; } - - //! Computes and return posterior hypers given data currently in this cluster - LinRegUni::Hyperparams compute_posterior_hypers() const; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return false; } - - //! Returns whether the hierarchy depends on covariate values or not - bool is_dependent() const override { return true; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const override; - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! @return The evaluation of the lpdf - double marg_lpdf(const LinRegUni::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate) const override; - - //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Dimension of the coefficients vector - unsigned int dim; - - //! Represents pieces of y^t y - double data_sum_squares; - - //! Represents pieces of X^T X - Eigen::MatrixXd covar_sum_squares; - - //! Represents pieces of X^t y - Eigen::VectorXd mixed_prod; -}; - -#endif // BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ diff --git a/src/hierarchies/.old/nnig_hierarchy.cc b/src/hierarchies/.old/nnig_hierarchy.cc deleted file mode 100644 index ff1ab3870..000000000 --- a/src/hierarchies/.old/nnig_hierarchy.cc +++ /dev/null @@ -1,267 +0,0 @@ -#include "nnig_hierarchy.h" - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "src/utils/rng.h" - -double NNIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { - return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); -} - -double NNIGHierarchy::marg_lpdf(const NNIG::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const { - double sig_n = sqrt(params.scale * (params.var_scaling + 1) / - (params.shape * params.var_scaling)); - return stan::math::student_t_lpdf(datum(0), 2 * params.shape, params.mean, - sig_n); -} - -void NNIGHierarchy::initialize_state() { - state.mean = hypers->mean; - state.var = hypers->scale / (hypers->shape + 1); -} - -void NNIGHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = prior->fixed_values().mean(); - hypers->var_scaling = prior->fixed_values().var_scaling(); - hypers->shape = prior->fixed_values().shape(); - hypers->scale = prior->fixed_values().scale(); - // Check validity - if (hypers->var_scaling <= 0) { - throw std::invalid_argument("Variance-scaling parameter must be > 0"); - } - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - } - - else if (prior->has_normal_mean_prior()) { - // Set initial values - hypers->mean = prior->normal_mean_prior().mean_prior().mean(); - hypers->var_scaling = prior->normal_mean_prior().var_scaling(); - hypers->shape = prior->normal_mean_prior().shape(); - hypers->scale = prior->normal_mean_prior().scale(); - // Check validity - if (hypers->var_scaling <= 0) { - throw std::invalid_argument("Variance-scaling parameter must be > 0"); - } - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - } - - else if (prior->has_ngg_prior()) { - // Get hyperparameters: - // for mu0 - double mu00 = prior->ngg_prior().mean_prior().mean(); - double sigma00 = prior->ngg_prior().mean_prior().var(); - // for lambda0 - double alpha00 = prior->ngg_prior().var_scaling_prior().shape(); - double beta00 = prior->ngg_prior().var_scaling_prior().rate(); - // for beta0 - double a00 = prior->ngg_prior().scale_prior().shape(); - double b00 = prior->ngg_prior().scale_prior().rate(); - // for alpha0 - double alpha0 = prior->ngg_prior().shape(); - // Check validity - if (sigma00 <= 0) { - throw std::invalid_argument("Variance parameter must be > 0"); - } - if (alpha00 <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (beta00 <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - if (a00 <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (b00 <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - if (alpha0 <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - // Set initial values - hypers->mean = mu00; - hypers->var_scaling = alpha00 / beta00; - hypers->shape = alpha0; - hypers->scale = a00 / b00; - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void NNIGHierarchy::update_hypers( - const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - - if (prior->has_fixed_values()) { - return; - } - - else if (prior->has_normal_mean_prior()) { - // Get hyperparameters - double mu00 = prior->normal_mean_prior().mean_prior().mean(); - double sig200 = prior->normal_mean_prior().mean_prior().var(); - double lambda0 = prior->normal_mean_prior().var_scaling(); - // Compute posterior hyperparameters - double prec = 0.0; - double num = 0.0; - for (auto &st : states) { - double mean = st.uni_ls_state().mean(); - double var = st.uni_ls_state().var(); - prec += 1 / var; - num += mean / var; - } - prec = 1 / sig200 + lambda0 * prec; - num = mu00 / sig200 + lambda0 * num; - double mu_n = num / prec; - double sig2_n = 1 / prec; - // Update hyperparameters with posterior random sampling - hypers->mean = stan::math::normal_rng(mu_n, sqrt(sig2_n), rng); - } - - else if (prior->has_ngg_prior()) { - // Get hyperparameters: - // for mu0 - double mu00 = prior->ngg_prior().mean_prior().mean(); - double sig200 = prior->ngg_prior().mean_prior().var(); - // for lambda0 - double alpha00 = prior->ngg_prior().var_scaling_prior().shape(); - double beta00 = prior->ngg_prior().var_scaling_prior().rate(); - // for tau0 - double a00 = prior->ngg_prior().scale_prior().shape(); - double b00 = prior->ngg_prior().scale_prior().rate(); - // Compute posterior hyperparameters - double b_n = 0.0; - double num = 0.0; - double beta_n = 0.0; - for (auto &st : states) { - double mean = st.uni_ls_state().mean(); - double var = st.uni_ls_state().var(); - b_n += 1 / var; - num += mean / var; - beta_n += (hypers->mean - mean) * (hypers->mean - mean) / var; - } - double var = hypers->var_scaling * b_n + 1 / sig200; - b_n += b00; - num = hypers->var_scaling * num + mu00 / sig200; - beta_n = beta00 + 0.5 * beta_n; - double sig_n = 1 / var; - double mu_n = num / var; - double alpha_n = alpha00 + 0.5 * states.size(); - double a_n = a00 + states.size() * hypers->shape; - // Update hyperparameters with posterior random Gibbs sampling - hypers->mean = stan::math::normal_rng(mu_n, sig_n, rng); - hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); - hypers->scale = stan::math::gamma_rng(a_n, b_n, rng); - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -NNIG::State NNIGHierarchy::draw(const NNIG::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - NNIG::State out; - out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - out.mean = stan::math::normal_rng(params.mean, - sqrt(state.var / params.var_scaling), rng); - return out; -} - -void NNIGHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, - const bool add) { - if (add) { - data_sum += datum(0); - data_sum_squares += datum(0) * datum(0); - } else { - data_sum -= datum(0); - data_sum_squares -= datum(0) * datum(0); - } -} - -void NNIGHierarchy::clear_summary_statistics() { - data_sum = 0; - data_sum_squares = 0; -} - -NNIG::Hyperparams NNIGHierarchy::compute_posterior_hypers() const { - // Initialize relevant variables - if (card == 0) { // no update possible - return *hypers; - } - // Compute posterior hyperparameters - NNIG::Hyperparams post_params; - double y_bar = data_sum / (1.0 * card); // sample mean - double ss = data_sum_squares - card * y_bar * y_bar; - post_params.mean = (hypers->var_scaling * hypers->mean + data_sum) / - (hypers->var_scaling + card); - post_params.var_scaling = hypers->var_scaling + card; - post_params.shape = hypers->shape + 0.5 * card; - post_params.scale = hypers->scale + 0.5 * ss + - 0.5 * hypers->var_scaling * card * - (y_bar - hypers->mean) * (y_bar - hypers->mean) / - (card + hypers->var_scaling); - return post_params; -} - -void NNIGHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.mean = statecast.uni_ls_state().mean(); - state.var = statecast.uni_ls_state().var(); - set_card(statecast.cardinality()); -} - -std::shared_ptr -NNIGHierarchy::get_state_proto() const { - bayesmix::UniLSState state_; - state_.set_mean(state.mean); - state_.set_var(state.var); - - auto out = std::make_shared(); - out->mutable_uni_ls_state()->CopyFrom(state_); - return out; -} - -void NNIGHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).nnig_state(); - hypers->mean = hyperscast.mean(); - hypers->var_scaling = hyperscast.var_scaling(); - hypers->scale = hyperscast.scale(); - hypers->shape = hyperscast.shape(); -} - -std::shared_ptr -NNIGHierarchy::get_hypers_proto() const { - bayesmix::NIGDistribution hypers_; - hypers_.set_mean(hypers->mean); - hypers_.set_var_scaling(hypers->var_scaling); - hypers_.set_shape(hypers->shape); - hypers_.set_scale(hypers->scale); - - auto out = std::make_shared(); - out->mutable_nnig_state()->CopyFrom(hypers_); - return out; -} diff --git a/src/hierarchies/.old/nnig_hierarchy.h b/src/hierarchies/.old/nnig_hierarchy.h deleted file mode 100644 index 7911691e9..000000000 --- a/src/hierarchies/.old/nnig_hierarchy.h +++ /dev/null @@ -1,122 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Conjugate Normal Normal-InverseGamma hierarchy for univariate data. - -//! This class represents a hierarchical model where data are distributed -//! according to a normal likelihood, the parameters of which have a -//! Normal-InverseGamma centering distribution. That is: -//! f(x_i|mu,sig) = N(mu,sig^2) -//! (mu,sig^2) ~ N-IG(mu0, lambda0, alpha0, beta0) -//! The state is composed of mean and variance. The state hyperparameters, -//! contained in the Hypers object, are (mu_0, lambda0, alpha0, beta0), all -//! scalar values. Note that this hierarchy is conjugate, thus the marginal -//! distribution is available in closed form. For more information, please -//! refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and -//! `ConjugateHierarchy`. - -namespace NNIG { -//! Custom container for State values -struct State { - double mean, var; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - double mean, var_scaling, shape, scale; -}; - -}; // namespace NNIG - -class NNIGHierarchy - : public ConjugateHierarchy { - public: - NNIGHierarchy() = default; - ~NNIGHierarchy() = default; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - NNIG::State draw(const NNIG::Hyperparams ¶ms); - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Returns the Protobuf ID associated to this class - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::NNIG; - } - - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Computes and return posterior hypers given data currently in this cluster - NNIG::Hyperparams compute_posterior_hypers() const; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return false; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum) const override; - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double marg_lpdf(const NNIG::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) override; - - //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Sum of data points currently belonging to the cluster - double data_sum = 0; - - //! Sum of squared data points currently belonging to the cluster - double data_sum_squares = 0; -}; - -#endif // BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ diff --git a/src/hierarchies/.old/nnw_hierarchy.cc b/src/hierarchies/.old/nnw_hierarchy.cc deleted file mode 100644 index 65038e946..000000000 --- a/src/hierarchies/.old/nnw_hierarchy.cc +++ /dev/null @@ -1,373 +0,0 @@ -#include "nnw_hierarchy.h" - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "matrix.pb.h" -#include "src/utils/distributions.h" -#include "src/utils/eigen_utils.h" -#include "src/utils/proto_utils.h" -#include "src/utils/rng.h" - -double NNWHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { - return bayesmix::multi_normal_prec_lpdf(datum, state.mean, state.prec_chol, - state.prec_logdet); -} - -double NNWHierarchy::marg_lpdf(const NNW::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const { - NNW::Hyperparams pred_params = get_predictive_t_parameters(params); - Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); - double logdet = 2 * log(diag.array()).sum(); - - return bayesmix::multi_student_t_invscale_lpdf( - datum, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, - logdet); -} - -Eigen::VectorXd NNWHierarchy::like_lpdf_grid( - const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { - // Custom, optimized grid method - return bayesmix::multi_normal_prec_lpdf_grid( - data, state.mean, state.prec_chol, state.prec_logdet); -} - -Eigen::VectorXd NNWHierarchy::prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { - // Custom, optimized grid method - NNW::Hyperparams pred_params = get_predictive_t_parameters(*hypers); - Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); - double logdet = 2 * log(diag.array()).sum(); - - return bayesmix::multi_student_t_invscale_lpdf_grid( - data, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, - logdet); -} - -Eigen::VectorXd NNWHierarchy::conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { - // Custom, optimized grid method - NNW::Hyperparams pred_params = - get_predictive_t_parameters(compute_posterior_hypers()); - Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); - double logdet = 2 * log(diag.array()).sum(); - - return bayesmix::multi_student_t_invscale_lpdf_grid( - data, pred_params.deg_free, pred_params.mean, pred_params.scale_chol, - logdet); -} - -void NNWHierarchy::initialize_state() { - state.mean = hypers->mean; - write_prec_to_state( - hypers->var_scaling * Eigen::MatrixXd::Identity(dim, dim), &state); -} - -void NNWHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = bayesmix::to_eigen(prior->fixed_values().mean()); - dim = hypers->mean.size(); - hypers->var_scaling = prior->fixed_values().var_scaling(); - hypers->scale = bayesmix::to_eigen(prior->fixed_values().scale()); - hypers->deg_free = prior->fixed_values().deg_free(); - // Check validity - if (hypers->var_scaling <= 0) { - throw std::invalid_argument("Variance-scaling parameter must be > 0"); - } - if (dim != hypers->scale.rows()) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - if (hypers->deg_free <= dim - 1) { - throw std::invalid_argument("Degrees of freedom parameter is not valid"); - } - } - - else if (prior->has_normal_mean_prior()) { - // Get hyperparameters - Eigen::VectorXd mu00 = - bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().mean()); - dim = mu00.size(); - Eigen::MatrixXd sigma00 = - bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().var()); - double lambda0 = prior->normal_mean_prior().var_scaling(); - Eigen::MatrixXd tau0 = - bayesmix::to_eigen(prior->normal_mean_prior().scale()); - double nu0 = prior->normal_mean_prior().deg_free(); - // Check validity - unsigned int dim = mu00.size(); - if (sigma00.rows() != dim or tau0.rows() != dim) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - bayesmix::check_spd(sigma00); - if (lambda0 <= 0) { - throw std::invalid_argument("Variance-scaling parameter must be > 0"); - } - bayesmix::check_spd(tau0); - if (nu0 <= dim - 1) { - throw std::invalid_argument("Degrees of freedom parameter is not valid"); - } - // Set initial values - hypers->mean = mu00; - hypers->var_scaling = lambda0; - hypers->scale = tau0; - hypers->deg_free = nu0; - } - - else if (prior->has_ngiw_prior()) { - // Get hyperparameters: - // for mu0 - Eigen::VectorXd mu00 = - bayesmix::to_eigen(prior->ngiw_prior().mean_prior().mean()); - dim = mu00.size(); - Eigen::MatrixXd sigma00 = - bayesmix::to_eigen(prior->ngiw_prior().mean_prior().var()); - // for lambda0 - double alpha00 = prior->ngiw_prior().var_scaling_prior().shape(); - double beta00 = prior->ngiw_prior().var_scaling_prior().rate(); - // for tau0 - double nu00 = prior->ngiw_prior().scale_prior().deg_free(); - Eigen::MatrixXd tau00 = - bayesmix::to_eigen(prior->ngiw_prior().scale_prior().scale()); - // for nu0 - double nu0 = prior->ngiw_prior().deg_free(); - // Check validity: - // dimensionality - if (sigma00.rows() != dim or tau00.rows() != dim) { - throw std::invalid_argument( - "Hyperparameters dimensions are not consistent"); - } - // for mu0 - bayesmix::check_spd(sigma00); - // for lambda0 - if (alpha00 <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (beta00 <= 0) { - throw std::invalid_argument("Rate parameter must be > 0"); - } - // for tau0 - if (nu00 <= 0) { - throw std::invalid_argument("Degrees of freedom parameter must be > 0"); - } - bayesmix::check_spd(tau00); - // check nu0 - if (nu0 <= dim - 1) { - throw std::invalid_argument("Degrees of freedom parameter is not valid"); - } - // Set initial values - hypers->mean = mu00; - hypers->var_scaling = alpha00 / beta00; - hypers->scale = tau00 / (nu00 + dim + 1); - hypers->deg_free = nu0; - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } - hypers->scale_inv = stan::math::inverse_spd(hypers->scale); - hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); -} - -void NNWHierarchy::update_hypers( - const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } - - else if (prior->has_normal_mean_prior()) { - // Get hyperparameters - Eigen::VectorXd mu00 = - bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().mean()); - Eigen::MatrixXd sigma00 = - bayesmix::to_eigen(prior->normal_mean_prior().mean_prior().var()); - double lambda0 = prior->normal_mean_prior().var_scaling(); - // Compute posterior hyperparameters - Eigen::MatrixXd sigma00inv = stan::math::inverse_spd(sigma00); - Eigen::MatrixXd prec = Eigen::MatrixXd::Zero(dim, dim); - Eigen::VectorXd num = Eigen::MatrixXd::Zero(dim, 1); - for (auto &st : states) { - Eigen::MatrixXd prec_i = bayesmix::to_eigen(st.multi_ls_state().prec()); - prec += prec_i; - num += prec_i * bayesmix::to_eigen(st.multi_ls_state().mean()); - } - prec = hypers->var_scaling * prec + sigma00inv; - num = hypers->var_scaling * num + sigma00inv * mu00; - Eigen::VectorXd mu_n = prec.llt().solve(num); - // Update hyperparameters with posterior sampling - hypers->mean = stan::math::multi_normal_prec_rng(mu_n, prec, rng); - } - - else if (prior->has_ngiw_prior()) { - // Get hyperparameters: - // for mu0 - Eigen::VectorXd mu00 = - bayesmix::to_eigen(prior->ngiw_prior().mean_prior().mean()); - Eigen::MatrixXd sigma00 = - bayesmix::to_eigen(prior->ngiw_prior().mean_prior().var()); - // for lambda0 - double alpha00 = prior->ngiw_prior().var_scaling_prior().shape(); - double beta00 = prior->ngiw_prior().var_scaling_prior().rate(); - // for tau0 - double nu00 = prior->ngiw_prior().scale_prior().deg_free(); - Eigen::MatrixXd tau00 = - bayesmix::to_eigen(prior->ngiw_prior().scale_prior().scale()); - // Compute posterior hyperparameters - Eigen::MatrixXd sigma00inv = stan::math::inverse_spd(sigma00); - Eigen::MatrixXd tau_n = Eigen::MatrixXd::Zero(dim, dim); - Eigen::VectorXd num = Eigen::MatrixXd::Zero(dim, 1); - double beta_n = 0.0; - for (auto &st : states) { - Eigen::VectorXd mean = bayesmix::to_eigen(st.multi_ls_state().mean()); - Eigen::MatrixXd prec = bayesmix::to_eigen(st.multi_ls_state().prec()); - tau_n += prec; - num += prec * mean; - beta_n += - (hypers->mean - mean).transpose() * prec * (hypers->mean - mean); - } - Eigen::MatrixXd prec_n = hypers->var_scaling * tau_n + sigma00inv; - tau_n += tau00; - num = hypers->var_scaling * num + sigma00inv * mu00; - beta_n = beta00 + 0.5 * beta_n; - Eigen::MatrixXd sig_n = stan::math::inverse_spd(prec_n); - Eigen::VectorXd mu_n = sig_n * num; - double alpha_n = alpha00 + 0.5 * states.size(); - double nu_n = nu00 + states.size() * hypers->deg_free; - // Update hyperparameters with posterior random Gibbs sampling - hypers->mean = stan::math::multi_normal_rng(mu_n, sig_n, rng); - hypers->var_scaling = stan::math::gamma_rng(alpha_n, beta_n, rng); - hypers->scale = stan::math::inv_wishart_rng(nu_n, tau_n, rng); - hypers->scale_inv = stan::math::inverse_spd(hypers->scale); - hypers->scale_chol = Eigen::LLT(hypers->scale).matrixU(); - } - - else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -NNW::State NNWHierarchy::draw(const NNW::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - Eigen::MatrixXd tau_new = - stan::math::wishart_rng(params.deg_free, params.scale, rng); - // Update state - NNW::State out; - out.mean = stan::math::multi_normal_prec_rng( - params.mean, tau_new * params.var_scaling, rng); - write_prec_to_state(tau_new, &out); - return out; -} - -void NNWHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, - const bool add) { - if (add) { - data_sum += datum.transpose(); - data_sum_squares += datum.transpose() * datum; - } else { - data_sum -= datum.transpose(); - data_sum_squares -= datum.transpose() * datum; - } -} - -void NNWHierarchy::clear_summary_statistics() { - data_sum = Eigen::VectorXd::Zero(dim); - data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); -} - -NNW::Hyperparams NNWHierarchy::compute_posterior_hypers() const { - if (card == 0) { // no update possible - return *hypers; - } - // Compute posterior hyperparameters - NNW::Hyperparams post_params; - post_params.var_scaling = hypers->var_scaling + card; - post_params.deg_free = hypers->deg_free + card; - Eigen::VectorXd mubar = data_sum.array() / card; // sample mean - post_params.mean = (hypers->var_scaling * hypers->mean + card * mubar) / - (hypers->var_scaling + card); - // Compute tau_n - Eigen::MatrixXd tau_temp = - data_sum_squares - card * mubar * mubar.transpose(); - tau_temp += (card * hypers->var_scaling / (card + hypers->var_scaling)) * - (mubar - hypers->mean) * (mubar - hypers->mean).transpose(); - post_params.scale_inv = tau_temp + hypers->scale_inv; - post_params.scale = stan::math::inverse_spd(post_params.scale_inv); - post_params.scale_chol = - Eigen::LLT(post_params.scale).matrixU(); - return post_params; -} - -void NNWHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.mean = to_eigen(statecast.multi_ls_state().mean()); - state.prec = to_eigen(statecast.multi_ls_state().prec()); - state.prec_chol = to_eigen(statecast.multi_ls_state().prec_chol()); - Eigen::VectorXd diag = state.prec_chol.diagonal(); - state.prec_logdet = 2 * log(diag.array()).sum(); - set_card(statecast.cardinality()); -} - -std::shared_ptr -NNWHierarchy::get_state_proto() const { - bayesmix::MultiLSState state_; - bayesmix::to_proto(state.mean, state_.mutable_mean()); - bayesmix::to_proto(state.prec, state_.mutable_prec()); - bayesmix::to_proto(state.prec_chol, state_.mutable_prec_chol()); - - auto out = std::make_shared(); - out->mutable_multi_ls_state()->CopyFrom(state_); - return out; -} - -void NNWHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).nnw_state(); - hypers->mean = to_eigen(hyperscast.mean()); - hypers->var_scaling = hyperscast.var_scaling(); - hypers->deg_free = hyperscast.deg_free(); - hypers->scale = to_eigen(hyperscast.scale()); -} - -std::shared_ptr -NNWHierarchy::get_hypers_proto() const { - bayesmix::NWDistribution hypers_; - bayesmix::to_proto(hypers->mean, hypers_.mutable_mean()); - hypers_.set_var_scaling(hypers->var_scaling); - hypers_.set_deg_free(hypers->deg_free); - bayesmix::to_proto(hypers->scale, hypers_.mutable_scale()); - - auto out = std::make_shared(); - out->mutable_nnw_state()->CopyFrom(hypers_); - return out; -} - -void NNWHierarchy::write_prec_to_state(const Eigen::MatrixXd &prec_, - NNW::State *out) { - out->prec = prec_; - // Update prec utilities - out->prec_chol = Eigen::LLT(prec_).matrixU(); - Eigen::VectorXd diag = out->prec_chol.diagonal(); - out->prec_logdet = 2 * log(diag.array()).sum(); -} - -NNW::Hyperparams NNWHierarchy::get_predictive_t_parameters( - const NNW::Hyperparams ¶ms) const { - // Compute dof and scale of marginal distribution - double nu_n = params.deg_free - dim + 1; - double coeff = (params.var_scaling + 1) / (params.var_scaling * nu_n); - Eigen::MatrixXd scale_chol_n = params.scale_chol / std::sqrt(coeff); - - NNW::Hyperparams out; - out.mean = params.mean; - out.deg_free = nu_n; - out.scale_chol = scale_chol_n; - return out; -} diff --git a/src/hierarchies/.old/nnw_hierarchy.h b/src/hierarchies/.old/nnw_hierarchy.h deleted file mode 100644 index 1b149d422..000000000 --- a/src/hierarchies/.old/nnw_hierarchy.h +++ /dev/null @@ -1,168 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Normal Normal-Wishart hierarchy for multivariate data. - -//! This class represents a hierarchy, i.e. a cluster, whose multivariate data -//! are distributed according to a multinomial normal likelihood, the -//! parameters of which have a Normal-Wishart centering distribution. That is: -//! f(x_i|mu,tau) = N(mu,tau^{-1}) -//! (mu,tau) ~ NW(mu0, lambda0, tau0, nu0) -//! The state is composed of mean and precision matrix. The Cholesky factor and -//! log-determinant of the latter are also included in the container for -//! efficiency reasons. The state's hyperparameters, contained in the Hypers -//! object, are (mu0, lambda0, tau0, nu0), which are respectively vector, -//! scalar, matrix, and scalar. Note that this hierarchy is conjugate, thus the -//! marginal distribution is available in closed form. For more information, -//! please refer to parent classes: `AbstractHierarchy`, `BaseHierarchy`, and -//! `ConjugateHierarchy`. - -namespace NNW { -//! Custom container for State values -struct State { - Eigen::VectorXd mean; - Eigen::MatrixXd prec; - Eigen::MatrixXd prec_chol; - double prec_logdet; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - Eigen::VectorXd mean; - double var_scaling; - double deg_free; - Eigen::MatrixXd scale; - Eigen::MatrixXd scale_inv; - Eigen::MatrixXd scale_chol; -}; -} // namespace NNW - -class NNWHierarchy - : public ConjugateHierarchy { - public: - NNWHierarchy() = default; - ~NNWHierarchy() = default; - - // EVALUATION FUNCTIONS FOR GRIDS OF POINTS - //! Evaluates the log-likelihood of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - Eigen::VectorXd like_lpdf_grid(const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = - Eigen::MatrixXd(0, 0)) const override; - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - Eigen::VectorXd prior_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Evaluates the log-prior predictive distr. of data in a grid of points - //! @param data Grid of points (by row) which are to be evaluated - //! @param covariates (Optional) covariate vectors associated to data - //! @return The evaluation of the lpdf - Eigen::VectorXd conditional_pred_lpdf_grid( - const Eigen::MatrixXd &data, - const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, - 0)) const override; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - NNW::State draw(const NNW::Hyperparams ¶ms); - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Returns the Protobuf ID associated to this class - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::NNW; - } - - //! Computes and return posterior hypers given data currently in this cluster - NNW::Hyperparams compute_posterior_hypers() const; - - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return true; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum) const override; - - //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double marg_lpdf(const NNW::Hyperparams ¶ms, - const Eigen::RowVectorXd &datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) override; - - //! Writes prec and its utilities to the given state object by pointer - void write_prec_to_state(const Eigen::MatrixXd &prec_, NNW::State *out); - - //! Returns parameters for the predictive Student's t distribution - NNW::Hyperparams get_predictive_t_parameters( - const NNW::Hyperparams ¶ms) const; - - //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Dimension of data space - unsigned int dim; - - //! Sum of data points currently belonging to the cluster - Eigen::VectorXd data_sum; - - //! Sum of squared data points currently belonging to the cluster - Eigen::MatrixXd data_sum_squares; -}; - -#endif // BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ diff --git a/src/hierarchies/.old/nnxig_hierarchy.cc b/src/hierarchies/.old/nnxig_hierarchy.cc deleted file mode 100644 index f7bc62a16..000000000 --- a/src/hierarchies/.old/nnxig_hierarchy.cc +++ /dev/null @@ -1,152 +0,0 @@ -#include "nnxig_hierarchy.h" - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "hierarchy_prior.pb.h" -#include "ls_state.pb.h" -#include "src/utils/rng.h" - -double NNxIGHierarchy::like_lpdf(const Eigen::RowVectorXd &datum) const { - return stan::math::normal_lpdf(datum(0), state.mean, sqrt(state.var)); -} - -void NNxIGHierarchy::initialize_state() { - state.mean = hypers->mean; - state.var = hypers->scale / (hypers->shape + 1); -} - -void NNxIGHierarchy::initialize_hypers() { - if (prior->has_fixed_values()) { - // Set values - hypers->mean = prior->fixed_values().mean(); - hypers->var = prior->fixed_values().var(); - hypers->shape = prior->fixed_values().shape(); - hypers->scale = prior->fixed_values().scale(); - - // Check validity - if (hypers->var <= 0) { - throw std::invalid_argument("Variance parameter must be > 0"); - } - if (hypers->shape <= 0) { - throw std::invalid_argument("Shape parameter must be > 0"); - } - if (hypers->scale <= 0) { - throw std::invalid_argument("scale parameter must be > 0"); - } - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void NNxIGHierarchy::update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) { - if (add) { - data_sum += datum(0); - data_sum_squares += datum(0) * datum(0); - } else { - data_sum -= datum(0); - data_sum_squares -= datum(0) * datum(0); - } -} - -void NNxIGHierarchy::update_hypers( - const std::vector &states) { - auto &rng = bayesmix::Rng::Instance().get(); - if (prior->has_fixed_values()) { - return; - } else { - throw std::invalid_argument("Unrecognized hierarchy prior"); - } -} - -void NNxIGHierarchy::clear_summary_statistics() { - data_sum = 0; - data_sum_squares = 0; -} - -void NNxIGHierarchy::set_state_from_proto( - const google::protobuf::Message &state_) { - auto &statecast = downcast_state(state_); - state.mean = statecast.uni_ls_state().mean(); - state.var = statecast.uni_ls_state().var(); - set_card(statecast.cardinality()); -} - -void NNxIGHierarchy::set_hypers_from_proto( - const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_).nnxig_state(); - hypers->mean = hyperscast.mean(); - hypers->var = hyperscast.var(); - hypers->scale = hyperscast.scale(); - hypers->shape = hyperscast.shape(); -} - -std::shared_ptr -NNxIGHierarchy::get_state_proto() const { - bayesmix::UniLSState state_; - state_.set_mean(state.mean); - state_.set_var(state.var); - - auto out = std::make_shared(); - out->mutable_uni_ls_state()->CopyFrom(state_); - return out; -} - -std::shared_ptr -NNxIGHierarchy::get_hypers_proto() const { - bayesmix::NxIGDistribution hypers_; - hypers_.set_mean(hypers->mean); - hypers_.set_var(hypers->var); - hypers_.set_shape(hypers->shape); - hypers_.set_scale(hypers->scale); - - auto out = std::make_shared(); - out->mutable_nnxig_state()->CopyFrom(hypers_); - return out; -} - -void NNxIGHierarchy::sample_full_cond(bool update_params) { - if (this->card == 0) { - // No posterior update possible - sample_prior(); - } else { - NNxIG::Hyperparams params = - update_params ? compute_posterior_hypers() : posterior_hypers; - state = draw(params); - } -} - -NNxIG::State NNxIGHierarchy::draw(const NNxIG::Hyperparams ¶ms) { - auto &rng = bayesmix::Rng::Instance().get(); - NNxIG::State out; - out.var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - out.mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng); - return out; -} - -NNxIG::Hyperparams NNxIGHierarchy::compute_posterior_hypers() const { - // Initialize relevant variables - if (card == 0) { // no update possible - return *hypers; - } - // Compute posterior hyperparameters - NNxIG::Hyperparams post_params; - double var_y = data_sum_squares - 2 * state.mean * data_sum + - card * state.mean * state.mean; - post_params.mean = (hypers->var * data_sum + state.var * hypers->mean) / - (card * hypers->var + state.var); - post_params.var = - (state.var * hypers->var) / (card * hypers->var + state.var); - post_params.shape = hypers->shape + 0.5 * card; - post_params.scale = hypers->scale + 0.5 * var_y; - return post_params; -} - -void NNxIGHierarchy::save_posterior_hypers() { - posterior_hypers = compute_posterior_hypers(); -} diff --git a/src/hierarchies/.old/nnxig_hierarchy.h b/src/hierarchies/.old/nnxig_hierarchy.h deleted file mode 100644 index de5e878f6..000000000 --- a/src/hierarchies/.old/nnxig_hierarchy.h +++ /dev/null @@ -1,120 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ - -#include - -#include -#include -#include - -#include "algorithm_state.pb.h" -#include "base_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "hierarchy_prior.pb.h" - -//! Non Conjugate Normal Normal-InverseGamma hierarchy for univariate data. - -//! This class represents a hierarchical model where data are distributed -//! according to a normal likelihood, the parameters of which have a -//! Normal-InverseGamma centering distribution. That is: -//! f(x_i|mu,sig) = N(mu,sig^2) -//! mu ~ N(mu0, sigma0) -//! sig^2 ~ IG(alpha0, beta0) -//! The state is composed of mean and variance. The state hyperparameters, -//! contained in the Hypers object, are (mu0, sigma0, alpha0, beta0), all -//! scalar values. Note that this hierarchy is non conjugate. - -namespace NNxIG { -//! Custom container for State values -struct State { - double mean, var; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - double mean, var, shape, scale; -}; - -}; // namespace NNxIG - -class NNxIGHierarchy - : public BaseHierarchy { - public: - NNxIGHierarchy() = default; - ~NNxIGHierarchy() = default; - - //! Updates hyperparameter values given a vector of cluster states - void update_hypers(const std::vector - &states) override; - - //! Updates state values using the given (prior or posterior) hyperparameters - NNxIG::State draw(const NNxIG::Hyperparams ¶ms); - - //! Generates new state values from the centering posterior distribution - //! @param update_params Save posterior hypers after the computation? - void sample_full_cond(bool update_params = true) override; - - //! Saves posterior hyperparameters to the corresponding class member - void save_posterior_hypers(); - - //! Resets summary statistics for this cluster - void clear_summary_statistics() override; - - //! Returns the Protobuf ID associated to this class - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::NNxIG; - } - - //! Read and set state values from a given Protobuf message - void set_state_from_proto(const google::protobuf::Message &state_) override; - - //! Read and set hyperparameter values from a given Protobuf message - void set_hypers_from_proto( - const google::protobuf::Message &hypers_) override; - - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - std::shared_ptr get_state_proto() - const override; - - //! Writes current value of hyperparameters to a Protobuf message and - //! return a shared_ptr. - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::HierarchyHypers message by adding the appropriate type - std::shared_ptr get_hypers_proto() - const override; - - //! Computes and return posterior hypers given data currently in this cluster - NNxIG::Hyperparams compute_posterior_hypers() const; - - //! Returns whether the hierarchy models multivariate data or not - bool is_multivariate() const override { return false; } - - protected: - //! Evaluates the log-likelihood of data in a single point - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf - double like_lpdf(const Eigen::RowVectorXd &datum) const override; - - //! Updates cluster statistics when a datum is added or removed from it - //! @param datum Data point which is being added or removed - //! @param add Whether the datum is being added or removed - void update_summary_statistics(const Eigen::RowVectorXd &datum, - bool add) override; - - //! Initializes state parameters to appropriate values - void initialize_state() override; - - //! Initializes hierarchy hyperparameters to appropriate values - void initialize_hypers() override; - - //! Sum of data points currently belonging to the cluster - double data_sum = 0; - - //! Sum of squared data points currently belonging to the cluster - double data_sum_squares = 0; -}; - -#endif // BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ From f98315818a8e6c5fe1c1e24a7bef6abe862787d5 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:54:28 +0100 Subject: [PATCH 214/317] Ignore .old directories --- src/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/src/.gitignore b/src/.gitignore index 1b5d07b0b..0c547d223 100644 --- a/src/.gitignore +++ b/src/.gitignore @@ -1 +1,2 @@ hierarchies/.old/ +hierarchies/likelihoods/states/.old/ From c59a849e9b1a66f7a5264fd35acdf31fac7b797e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:55:01 +0100 Subject: [PATCH 215/317] Delete src/hierarchies/likelihoods/states/.old directory --- .../likelihoods/states/.old/states.h | 132 ------------------ 1 file changed, 132 deletions(-) delete mode 100644 src/hierarchies/likelihoods/states/.old/states.h diff --git a/src/hierarchies/likelihoods/states/.old/states.h b/src/hierarchies/likelihoods/states/.old/states.h deleted file mode 100644 index 60df73e4d..000000000 --- a/src/hierarchies/likelihoods/states/.old/states.h +++ /dev/null @@ -1,132 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_H_ -#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_H_ - -#include -#include - -#include "algorithm_state.pb.h" -#include "src/utils/proto_utils.h" - -namespace State { - -template -Eigen::Matrix uni_ls_to_constrained( - Eigen::Matrix in) { - Eigen::Matrix out(2); - out << in(0), stan::math::exp(in(1)); - return out; -} - -template -Eigen::Matrix uni_ls_to_unconstrained( - Eigen::Matrix in) { - Eigen::Matrix out(2); - out << in(0), stan::math::log(in(1)); - return out; -} - -template -T uni_ls_log_det_jac(Eigen::Matrix constrained) { - T out = 0; - stan::math::positive_constrain(stan::math::log(constrained(1)), out); - return out; -} - -class UniLS { - public: - double mean, var; - - Eigen::VectorXd get_unconstrained() { - Eigen::VectorXd temp(2); - temp << mean, var; - return uni_ls_to_unconstrained(temp); - } - - void set_from_unconstrained(Eigen::VectorXd in) { - Eigen::VectorXd temp = uni_ls_to_constrained(in); - mean = temp(0); - var = temp(1); - } - - void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { - mean = state_.uni_ls_state().mean(); - var = state_.uni_ls_state().var(); - } - - bayesmix::AlgorithmState::ClusterState get_as_proto() { - bayesmix::AlgorithmState::ClusterState state; - state.mutable_uni_ls_state()->set_mean(mean); - state.mutable_uni_ls_state()->set_var(var); - return state; - } - - double log_det_jac() { - Eigen::VectorXd temp(2); - temp << mean, var; - return uni_ls_log_det_jac(temp); - } -}; - -class MultiLS { - public: - Eigen::VectorXd mean; - Eigen::MatrixXd prec, prec_chol; - double prec_logdet; - - Eigen::VectorXd get_unconstrained() { - Eigen::VectorXd out_prec = stan::math::cov_matrix_free(prec); - Eigen::VectorXd out(mean.size() + out_prec.size()); - out << mean, out_prec; - return out; - } - - void set_from_constrained(Eigen::VectorXd mean_, Eigen::MatrixXd prec_) { - mean = mean_; - prec = prec_; - prec_chol = Eigen::LLT(prec).matrixL(); - Eigen::VectorXd diag = prec_chol.diagonal(); - prec_logdet = 2 * log(diag.array()).sum(); - } - - void set_from_unconstrained(Eigen::VectorXd in) { - double dim_ = 0.5 * (std::sqrt(8 * in.size() + 9) - 3); - double dim; - assert(modf(dim_, &dim) == 0.0); - mean = in.head(int(dim)); - prec = - stan::math::cov_matrix_constrain(in.tail(int(in.size() - dim)), dim); - set_from_constrained(mean, prec); - } - - void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { - mean = to_eigen(state_.multi_ls_state().mean()); - prec = to_eigen(state_.multi_ls_state().prec()); - prec_chol = to_eigen(state_.multi_ls_state().prec_chol()); - Eigen::VectorXd diag = prec_chol.diagonal(); - prec_logdet = 2 * log(diag.array()).sum(); - } - - bayesmix::AlgorithmState::ClusterState get_as_proto() { - bayesmix::AlgorithmState::ClusterState state; - bayesmix::to_proto(mean, state.mutable_multi_ls_state()->mutable_mean()); - bayesmix::to_proto(prec, state.mutable_multi_ls_state()->mutable_prec()); - bayesmix::to_proto(prec_chol, - state.mutable_multi_ls_state()->mutable_prec_chol()); - return state; - } - - double log_det_jac() { - double out = 0; - stan::math::positive_constrain(stan::math::cov_matrix_free(prec), out); - return out; - } -}; - -struct UniLinReg { - Eigen::VectorXd regression_coeffs; - double var; -}; - -} // namespace State - -#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_H_ From b33bba167f43b3c99ec298f65486744e1526a9a9 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 21 Mar 2022 16:58:44 +0100 Subject: [PATCH 216/317] Delete src/.gitignore --- src/.gitignore | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 src/.gitignore diff --git a/src/.gitignore b/src/.gitignore deleted file mode 100644 index 0c547d223..000000000 --- a/src/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -hierarchies/.old/ -hierarchies/likelihoods/states/.old/ From e8a19d51c67d94fc0737f51bdcb9e678cef5c1c7 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 22 Mar 2022 14:32:08 +0100 Subject: [PATCH 217/317] Move set_likelihood and set_prior to BaseHIerarchy --- src/hierarchies/abstract_hierarchy.h | 15 ++++++++------- src/hierarchies/base_hierarchy.h | 10 +++++----- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 692b80b71..7e141ae68 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -51,13 +51,14 @@ class AbstractHierarchy { public: - //! Set the likelihood for the current hierarchy. Implemented in the - //! BaseHierarchy class - virtual void set_likelihood(std::shared_ptr like_) = 0; - - //! Set the prior model for the current hierarchy. Implemented in the - //! BaseHierarchy class - virtual void set_prior(std::shared_ptr prior_) = 0; + // Set the likelihood for the current hierarchy. Implemented in the + // BaseHierarchy class + // virtual void set_likelihood(std::shared_ptr like_) = + // 0; + + // Set the prior model for the current hierarchy. Implemented in the + // BaseHierarchy class + // virtual void set_prior(std::shared_ptr prior_) = 0; //! Set the update algorithm for the current hierarchy. Implemented in the //! BaseHierarchy class diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index d78db2ba6..9ccfd15ff 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -65,17 +65,17 @@ class BaseHierarchy : public AbstractHierarchy { //! Default destructor ~BaseHierarchy() = default; - //! Set the likelihood for the current hierarchy - void set_likelihood(std::shared_ptr like_) override { + //! Sets the likelihood for the current hierarchy + void set_likelihood(std::shared_ptr like_) /*override*/ { like = std::static_pointer_cast(like_); } - //! Set the prior model for the current hierarchy - void set_prior(std::shared_ptr prior_) override { + //! Sets the prior model for the current hierarchy + void set_prior(std::shared_ptr prior_) /*override*/ { prior = std::static_pointer_cast(prior_); } - //! Set the update algorithm for the current hierarchy + //! Sets the update algorithm for the current hierarchy void set_updater(std::shared_ptr updater_) override { updater = updater_; }; From 97008892cf13018746d5ee21db2c24a54bae25d2 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 22 Mar 2022 14:33:24 +0100 Subject: [PATCH 218/317] Remove hypers.card leftovers --- src/hierarchies/fa_hierarchy.h | 2 +- src/hierarchies/priors/hyperparams.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index f68ba3fdd..2fb0bd7a3 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -40,7 +40,7 @@ class FAHierarchy State::FA state; state.mu = hypers.mutilde; state.psi = hypers.beta / (hypers.alpha0 + 1.); - state.eta = Eigen::MatrixXd::Zero(hypers.card, hypers.q); + // state.eta = Eigen::MatrixXd::Zero(hypers.card, hypers.q); state.lambda = Eigen::MatrixXd::Zero(dim, hypers.q); state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); like->set_state(state); diff --git a/src/hierarchies/priors/hyperparams.h b/src/hierarchies/priors/hyperparams.h index 04dcbd84b..acdcd8f7a 100644 --- a/src/hierarchies/priors/hyperparams.h +++ b/src/hierarchies/priors/hyperparams.h @@ -28,7 +28,7 @@ struct MNIG { struct FA { Eigen::VectorXd mutilde, beta; double phi, alpha0; - unsigned int card, q; + unsigned int /*card,*/ q; }; } // namespace Hyperparams From fdb06b2b0b72d85d19ad874f344d57509160dcd1 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 22 Mar 2022 14:42:21 +0100 Subject: [PATCH 219/317] Small doc change --- src/hierarchies/likelihoods/base_likelihood.h | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 0a471a5d6..369334400 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -92,13 +92,12 @@ class BaseLikelihood : public AbstractLikelihood { static_cast(*this), unconstrained_params, 0); } - //! Evaluates the log likelihood over all the data in the cluster - //! given unconstrained parameter values. - //! By unconstrained parameters we mean that each entry of - //! the parameter vector can range over (-inf, inf). - //! Usually, some kind of transformation is required from the unconstrained - //! parameterization to the actual parameterization. This version using - //! `stan::math::var` type is required for Stan automatic aifferentiation. + //! This version using `stan::math::var` type is required for Stan automatic + //! differentiation. Evaluates the log likelihood over all the data in the + //! cluster given unconstrained parameter values. By unconstrained parameters + //! we mean that each entry of the parameter vector can range over (-inf, + //! inf). Usually, some kind of transformation is required from the + //! unconstrained parameterization to the actual parameterization. //! @param unconstrained_params vector collecting the unconstrained //! parameters //! @return The evaluation of the log likelihood over all data in the cluster @@ -214,8 +213,9 @@ void BaseLikelihood::add_datum( const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) { assert(cluster_data_idx.find(id) == cluster_data_idx.end()); - card += 1; - log_card = std::log(card); + set_card(++card); + // card += 1; + // log_card = std::log(card); static_cast(this)->update_summary_statistics(datum, covariate, true); cluster_data_idx.insert(id); @@ -227,7 +227,7 @@ void BaseLikelihood::remove_datum( const Eigen::RowVectorXd &covariate) { static_cast(this)->update_summary_statistics(datum, covariate, false); - set_card(card - 1); + set_card(--card); auto it = cluster_data_idx.find(id); assert(it != cluster_data_idx.end()); cluster_data_idx.erase(it); From 8fc0c9c1c913f948860f92cb56bf7fade18e6ed8 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 22 Mar 2022 14:43:00 +0100 Subject: [PATCH 220/317] FAHierarchy relies on dataset_ptr --- .../likelihoods/laplace_likelihood.cc | 35 ++++++++++--------- .../likelihoods/laplace_likelihood.h | 17 +++++---- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/src/hierarchies/likelihoods/laplace_likelihood.cc b/src/hierarchies/likelihoods/laplace_likelihood.cc index 61bad8429..3c29632c8 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.cc +++ b/src/hierarchies/likelihoods/laplace_likelihood.cc @@ -5,18 +5,19 @@ double LaplaceLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { datum(0), state.mean, stan::math::sqrt(state.var / 2.0)); } -void LaplaceLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, - bool add) { - if (add) { - // sum_abs_diff_curr += std::abs(state.mean - datum(0, 0)); - cluster_data_values.push_back(datum); - } else { - // sum_abs_diff_curr -= std::abs(state.mean - datum(0, 0)); - auto it = std::find(cluster_data_values.begin(), cluster_data_values.end(), - datum); - cluster_data_values.erase(it); - } -} +// void LaplaceLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, +// bool add) { +// if (add) { +// sum_abs_diff_curr += std::abs(state.mean - datum(0, 0)); +// cluster_data_values.push_back(datum); +// } else { +// sum_abs_diff_curr -= std::abs(state.mean - datum(0, 0)); +// auto it = std::find(cluster_data_values.begin(), +// cluster_data_values.end(), +// datum); +// cluster_data_values.erase(it); +// } +// } void LaplaceLikelihood::set_state_from_proto( const google::protobuf::Message &state_, bool update_card) { @@ -34,11 +35,11 @@ LaplaceLikelihood::get_state_proto() const { return out; } -void LaplaceLikelihood::clear_summary_statistics() { - cluster_data_values.clear(); - // sum_abs_diff_curr = 0; - // sum_abs_diff_prop = 0; -} +// void LaplaceLikelihood::clear_summary_statistics() { +// cluster_data_values.clear(); +// sum_abs_diff_curr = 0; +// sum_abs_diff_prop = 0; +// } // double UniNormLikelihood::cluster_lpdf_from_unconstrained( // Eigen::VectorXd unconstrained_params) { diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h index 62d8fc14c..84d5a200a 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.h +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -20,18 +20,20 @@ class LaplaceLikelihood bool is_dependent() const override { return false; }; void set_state_from_proto(const google::protobuf::Message &state_, bool update_card = true) override; - void clear_summary_statistics() override; + void clear_summary_statistics() override { return; }; template T cluster_lpdf_from_unconstrained( const Eigen::Matrix &unconstrained_params) const { assert(unconstrained_params.size() == 2); + T mean = unconstrained_params(0); T var = stan::math::positive_constrain(unconstrained_params(1)); + T out = 0.; - for (auto it = cluster_data_values.begin(); - it != cluster_data_values.end(); ++it) { - out += stan::math::double_exponential_lpdf(*it, mean, + for (auto it = cluster_data_idx.begin(); it != cluster_data_idx.end(); + ++it) { + out += stan::math::double_exponential_lpdf(dataset_ptr->row(*it), mean, stan::math::sqrt(var / 2.0)); } return out; @@ -42,11 +44,12 @@ class LaplaceLikelihood protected: double compute_lpdf(const Eigen::RowVectorXd &datum) const override; - void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; + void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override { + return; + }; - // TODO: ORA CHE HO IL DATASET QUESTO NON SERVE! //! Set of values of data points belonging to this cluster - std::list cluster_data_values; + // std::list cluster_data_values; //! Sum of absolute differences for current params // double sum_abs_diff_curr = 0; //! Sum of absolute differences for proposal params From 107d6eb5131aa2de8a78bcc88328d043b46d24c3 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 22 Mar 2022 14:43:41 +0100 Subject: [PATCH 221/317] Bug fixed --- src/hierarchies/likelihoods/states/multi_ls_state.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h index b83970847..eb97fc5ee 100644 --- a/src/hierarchies/likelihoods/states/multi_ls_state.h +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -35,12 +35,13 @@ multi_ls_to_constrained(Eigen::Matrix in) { return std::make_tuple(mean, prec); } +// SEE GitHub for tests template T multi_ls_log_det_jac( Eigen::Matrix prec_constrained) { T out = 0; - stan::math::positive_constrain(stan::math::cov_matrix_free(prec_constrained), - out); + stan::math::cov_matrix_constrain( + stan::math::cov_matrix_free(prec_constrained), out); return out; } From 6d0b24399a5891d6a0dfdcd68d216bd1b541e28e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 23 Mar 2022 11:16:58 +0100 Subject: [PATCH 222/317] Small notebook changes --- python/notebooks/gaussian_mix_uni.ipynb | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/notebooks/gaussian_mix_uni.ipynb b/python/notebooks/gaussian_mix_uni.ipynb index 6e74c6142..7a9c71ed1 100644 --- a/python/notebooks/gaussian_mix_uni.ipynb +++ b/python/notebooks/gaussian_mix_uni.ipynb @@ -210,24 +210,24 @@ "neal2_algo = \"\"\"\n", "algo_id: \"Neal2\"\n", "rng_seed: 20201124\n", - "iterations: 10\n", - "burnin: 5\n", + "iterations: 100\n", + "burnin: 50\n", "init_num_clusters: 3\n", "\"\"\"\n", "\n", "neal3_algo = \"\"\"\n", "algo_id: \"Neal3\"\n", "rng_seed: 20201124\n", - "iterations: 10\n", - "burnin: 5\n", + "iterations: 100\n", + "burnin: 50\n", "init_num_clusters: 3\n", "\"\"\"\n", "\n", "neal8_algo = \"\"\"\n", "algo_id: \"Neal8\"\n", "rng_seed: 20201124\n", - "iterations: 1000\n", - "burnin: 500\n", + "iterations: 100\n", + "burnin: 50\n", "init_num_clusters: 3\n", "neal8_n_aux: 3\n", "\"\"\"\n", From daa84505b8dd5ab8926fc9cf42e6aaff0acbf6b0 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 23 Mar 2022 16:48:47 +0100 Subject: [PATCH 223/317] Include FAHierarchy --- src/includes.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/includes.h b/src/includes.h index 03ea6dcb0..d4d19b8ca 100644 --- a/src/includes.h +++ b/src/includes.h @@ -9,10 +9,10 @@ #include "algorithms/neal8_algorithm.h" #include "collectors/file_collector.h" #include "collectors/memory_collector.h" +#include "hierarchies/fa_hierarchy.h" #include "hierarchies/lapnig_hierarchy.h" #include "hierarchies/lin_reg_uni_hierarchy.h" #include "hierarchies/load_hierarchies.h" -// #include "hierarchies/fa_hierarchy.h" #include "hierarchies/nnig_hierarchy.h" #include "hierarchies/nnw_hierarchy.h" #include "hierarchies/nnxig_hierarchy.h" From 65d99c1f8b92ee1d5a5beab6a8c69863fc38efdf Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 23 Mar 2022 16:49:22 +0100 Subject: [PATCH 224/317] Change in target sources order --- src/hierarchies/updaters/CMakeLists.txt | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index 3a35d9f3e..17e4547a9 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -1,13 +1,9 @@ target_sources(bayesmix PUBLIC abstract_updater.h - conjugate_updater.h - mala_updater.h - metropolis_updater.h + # conjugate_updater.h + semi_conjugate_updater.h nnig_updater.h nnig_updater.cc - random_walk_updater.h - target_lpdf_unconstrained.h - target_lpdf_unconstrained.cc nnxig_updater.h nnxig_updater.cc nnw_updater.h @@ -16,4 +12,9 @@ target_sources(bayesmix PUBLIC mnig_updater.cc fa_updater.h fa_updater.cc + metropolis_updater.h + mala_updater.h + random_walk_updater.h + target_lpdf_unconstrained.h + target_lpdf_unconstrained.cc ) From c431df126ace26dd93e549ae4b6f7f789107b2d9 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Wed, 23 Mar 2022 16:50:27 +0100 Subject: [PATCH 225/317] Posterior_hypers are now managed by updaters --- src/hierarchies/base_hierarchy.h | 33 ++++--- src/hierarchies/lin_reg_uni_hierarchy.h | 22 +++-- src/hierarchies/nnig_hierarchy.h | 11 +-- src/hierarchies/nnw_hierarchy.h | 20 +++-- src/hierarchies/priors/abstract_prior_model.h | 7 +- src/hierarchies/priors/base_prior_model.h | 22 ++--- src/hierarchies/priors/fa_prior_model.cc | 86 +++++++++++++------ src/hierarchies/priors/fa_prior_model.h | 7 +- src/hierarchies/priors/mnig_prior_model.cc | 36 ++++++-- src/hierarchies/priors/mnig_prior_model.h | 6 +- src/hierarchies/priors/nig_prior_model.cc | 29 +++++-- src/hierarchies/priors/nig_prior_model.h | 7 +- src/hierarchies/priors/nw_prior_model.cc | 37 ++++++-- src/hierarchies/priors/nw_prior_model.h | 7 +- src/hierarchies/priors/nxig_prior_model.cc | 22 ++++- src/hierarchies/priors/nxig_prior_model.h | 7 +- src/hierarchies/updaters/abstract_updater.h | 17 +++- src/hierarchies/updaters/fa_updater.cc | 15 ++-- src/hierarchies/updaters/mnig_updater.cc | 84 +++++++++++++----- src/hierarchies/updaters/mnig_updater.h | 13 ++- src/hierarchies/updaters/nnig_updater.cc | 74 ++++++++++++---- src/hierarchies/updaters/nnig_updater.h | 14 ++- src/hierarchies/updaters/nnw_updater.cc | 78 +++++++++++++---- src/hierarchies/updaters/nnw_updater.h | 14 ++- src/hierarchies/updaters/nnxig_updater.cc | 61 ++++++++++--- src/hierarchies/updaters/nnxig_updater.h | 15 ++-- .../updaters/semi_conjugate_updater.h | 66 ++++++++++++++ test/likelihoods.cc | 5 +- test/prior_models.cc | 23 +++-- 29 files changed, 628 insertions(+), 210 deletions(-) create mode 100644 src/hierarchies/updaters/semi_conjugate_updater.h diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 9ccfd15ff..83505135c 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -43,6 +43,7 @@ class BaseHierarchy : public AbstractHierarchy { public: using HyperParams = decltype(prior->get_hypers()); + using ProtoHypers = AbstractUpdater::ProtoHypers; //! Constructor that allows the specification of Likelihood, PriorModel and //! Updater for a given Hierarchy @@ -149,12 +150,12 @@ class BaseHierarchy : public AbstractHierarchy { // ADD EXCEPTION HANDLING //! Public wrapper for `marg_lpdf()` methods double get_marg_lpdf( - const HyperParams ¶ms, const Eigen::RowVectorXd &datum, + const ProtoHypers &hier_params, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const { if (this->is_dependent()) { - return marg_lpdf(params, datum, covariate); + return marg_lpdf(hier_params, datum, covariate); } else { - return marg_lpdf(params, datum); + return marg_lpdf(hier_params, datum); } } @@ -166,7 +167,7 @@ class BaseHierarchy : public AbstractHierarchy { double prior_pred_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const override { - return get_marg_lpdf(prior->get_hypers(), datum, covariate); + return get_marg_lpdf(*(prior->get_hypers_proto()), datum, covariate); } // ADD EXCEPTION HANDLING @@ -209,7 +210,8 @@ class BaseHierarchy : public AbstractHierarchy { double conditional_pred_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const override { - return get_marg_lpdf(prior->get_posterior_hypers(), datum, covariate); + return get_marg_lpdf(updater->compute_posterior_hypers(*like, *prior), + datum, covariate); } // ADD EXCEPTION HANDLING @@ -246,7 +248,8 @@ class BaseHierarchy : public AbstractHierarchy { //! Generates new state values from the centering prior distribution void sample_prior() override { - like->set_state_from_proto(*prior->sample(false), false); + auto hypers = prior->get_hypers_proto(); + like->set_state_from_proto(*prior->sample(*hypers), false); }; //! Generates new state values from the centering posterior distribution @@ -344,7 +347,10 @@ class BaseHierarchy : public AbstractHierarchy { const bool update_params = false, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override { like->add_datum(id, datum, covariate); - if (update_params) updater->compute_posterior_hypers(*like, *prior); + if (update_params) { + updater->save_posterior_hypers( + updater->compute_posterior_hypers(*like, *prior)); + } }; //! Removes a datum and its index from the hierarchy @@ -353,13 +359,18 @@ class BaseHierarchy : public AbstractHierarchy { const bool update_params = false, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) override { like->remove_datum(id, datum, covariate); - if (update_params) updater->compute_posterior_hypers(*like, *prior); + if (update_params) { + updater->save_posterior_hypers( + updater->compute_posterior_hypers(*like, *prior)); + } }; //! Main function that initializes members to appropriate values void initialize() override { prior->initialize(); - if (is_conjugate()) prior->set_posterior_hypers(prior->get_hypers()); + if (is_conjugate()) { + updater->save_posterior_hypers(*prior->get_hypers_proto()); + } initialize_state(); like->clear_data(); like->clear_summary_statistics(); @@ -388,7 +399,7 @@ class BaseHierarchy : public AbstractHierarchy { //! @param params Container of (prior or posterior) hyperparameter values //! @param datum Point which is to be evaluated //! @return The evaluation of the lpdf - virtual double marg_lpdf(const HyperParams ¶ms, + virtual double marg_lpdf(const ProtoHypers &hier_params, const Eigen::RowVectorXd &datum) const { if (!is_conjugate()) { throw std::runtime_error( @@ -405,7 +416,7 @@ class BaseHierarchy : public AbstractHierarchy { //! @param datum Point which is to be evaluated //! @param covariate Covariate vector associated to datum //! @return The evaluation of the lpdf - virtual double marg_lpdf(const HyperParams ¶ms, + virtual double marg_lpdf(const ProtoHypers &hier_params, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const { if (!is_conjugate()) { diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index a76d44f58..b48efaa71 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -44,14 +44,22 @@ class LinRegUniHierarchy like->set_state(state); }; - double marg_lpdf(const HyperParams ¶ms, const Eigen::RowVectorXd &datum, + double marg_lpdf(const ProtoHypers &hier_params, + const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const override { - double sig_n = sqrt( - (1 + (covariate * params.var_scaling_inv * covariate.transpose())(0)) * - params.scale / params.shape); - return stan::math::student_t_lpdf(datum(0), 2 * params.shape, - covariate.dot(params.mean), sig_n); - } + auto params = hier_params.lin_reg_uni_state(); + Eigen::VectorXd mean = bayesmix::to_eigen(params.mean()); + Eigen::MatrixXd var_scaling = bayesmix::to_eigen(params.var_scaling()); + + auto I = Eigen::MatrixXd::Identity(prior->get_dim(), prior->get_dim()); + Eigen::MatrixXd var_scaling_inv = var_scaling.llt().solve(I); + + double sig_n = + sqrt((1 + (covariate * var_scaling_inv * covariate.transpose())(0)) * + params.scale() / params.shape()); + return stan::math::student_t_lpdf(datum(0), 2 * params.shape(), + covariate.dot(mean), sig_n); + }; }; #endif // BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index d6f13febc..294f1ee12 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -41,12 +41,13 @@ class NNIGHierarchy like->set_state(state); }; - double marg_lpdf(const HyperParams ¶ms, + double marg_lpdf(const ProtoHypers &hier_params, const Eigen::RowVectorXd &datum) const override { - double sig_n = sqrt(params.scale * (params.var_scaling + 1) / - (params.shape * params.var_scaling)); - return stan::math::student_t_lpdf(datum(0), 2 * params.shape, params.mean, - sig_n); + auto params = hier_params.nnig_state(); + double sig_n = sqrt(params.scale() * (params.var_scaling() + 1) / + (params.shape() * params.var_scaling())); + return stan::math::student_t_lpdf(datum(0), 2 * params.shape(), + params.mean(), sig_n); } }; diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index 12d1c543e..9c8c7c565 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -44,9 +44,9 @@ class NNWHierarchy like->set_state(state); }; - double marg_lpdf(const HyperParams ¶ms, + double marg_lpdf(const ProtoHypers &hier_params, const Eigen::RowVectorXd &datum) const override { - HyperParams pred_params = get_predictive_t_parameters(params); + HyperParams pred_params = get_predictive_t_parameters(hier_params); Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); double logdet = 2 * log(diag.array()).sum(); return bayesmix::multi_student_t_invscale_lpdf( @@ -54,15 +54,21 @@ class NNWHierarchy logdet); } - HyperParams get_predictive_t_parameters(const HyperParams ¶ms) const { + HyperParams get_predictive_t_parameters( + const ProtoHypers &hier_params) const { + auto params = hier_params.nnw_state(); // Compute dof and scale of marginal distribution unsigned int dim = like->get_dim(); - double nu_n = params.deg_free - dim + 1; - double coeff = (params.var_scaling + 1) / (params.var_scaling * nu_n); - Eigen::MatrixXd scale_chol_n = params.scale_chol / std::sqrt(coeff); + double nu_n = params.deg_free() - dim + 1; + double coeff = (params.var_scaling() + 1) / (params.var_scaling() * nu_n); + // Eigen::MatrixXd scale = bayesmix::to_eigen(params.scale()); + Eigen::MatrixXd scale_chol = + Eigen::LLT(bayesmix::to_eigen(params.scale())) + .matrixU(); + Eigen::MatrixXd scale_chol_n = scale_chol / std::sqrt(coeff); // Return predictive t parameters HyperParams out; - out.mean = params.mean; + out.mean = bayesmix::to_eigen(params.mean()); out.deg_free = nu_n; out.scale_chol = scale_chol_n; return out; diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 8531264fa..d1f839ad3 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -63,8 +63,11 @@ class AbstractPriorModel { //! @param use_post_hypers It is a `bool` which decides whether to use prior //! or posterior parameters //! @return A Protobuf message storing the state sampled from the prior model + // virtual std::shared_ptr sample( + // bool use_post_hypers) = 0; + virtual std::shared_ptr sample( - bool use_post_hypers) = 0; + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) = 0; //! Updates hyperparameter values given a vector of cluster states virtual void update_hypers( @@ -81,7 +84,6 @@ class AbstractPriorModel { //! pointer. Implemented in BasePriorModel virtual void write_hypers_to_proto(google::protobuf::Message *out) const = 0; - protected: //! Writes current value of hyperparameters to a Protobuf message and //! return a shared_ptr. //! New hierarchies have to first modify the field 'oneof val' in the @@ -89,6 +91,7 @@ class AbstractPriorModel { virtual std::shared_ptr get_hypers_proto() const = 0; + protected: //! Initializes hierarchy hyperparameters to appropriate values virtual void initialize_hypers() = 0; }; diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index cb3f71705..47403f2f5 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -57,12 +57,12 @@ class BasePriorModel : public AbstractPriorModel { static_cast(*this), unconstrained_params, 0); } - //! Evaluates the log likelihood for unconstrained parameter values. - //! By unconstrained parameters we mean that each entry of - //! the parameter vector can range over (-inf, inf). - //! Usually, some kind of transformation is required from the unconstrained - //! parameterization to the actual parameterization. This version using - //! `stan::math::var` type is required for Stan automatic aifferentiation. + //! This version using `stan::math::var` type is required for Stan automatic + //! aifferentiation. Evaluates the log likelihood for unconstrained parameter + //! values. By unconstrained parameters we mean that each entry of the + //! parameter vector can range over (-inf, inf). Usually, some kind of + //! transformation is required from the unconstrained parameterization to the + //! actual parameterization. //! @param unconstrained_params vector collecting the unconstrained //! parameters //! @return The evaluation of the log likelihood of the prior model @@ -86,12 +86,12 @@ class BasePriorModel : public AbstractPriorModel { HyperParams get_hypers() const { return *hypers; } //! Returns the struct of the current posterior hyperparameters - HyperParams get_posterior_hypers() const { return post_hypers; } + // HyperParams get_posterior_hypers() const { return post_hypers; } //! Updates the current value of the posterior hyperparameters - void set_posterior_hypers(const HyperParams &_post_hypers) { - post_hypers = _post_hypers; - }; + // void set_posterior_hypers(const HyperParams &_post_hypers) { + // post_hypers = _post_hypers; + // }; //! Writes current values of the hyperparameters to a Protobuf message by //! pointer @@ -136,7 +136,7 @@ class BasePriorModel : public AbstractPriorModel { std::shared_ptr hypers = std::make_shared(); //! Container for posterior hyperparameters values - HyperParams post_hypers; + // HyperParams post_hypers; //! Pointer to a Protobuf prior object for this class std::shared_ptr prior; diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc index 43abee380..4b644d860 100644 --- a/src/hierarchies/priors/fa_prior_model.cc +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -30,53 +30,87 @@ double FAPriorModel::lpdf(const google::protobuf::Message &state_) { } std::shared_ptr FAPriorModel::sample( - bool use_post_hypers) { + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) { // Random seed auto &rng = bayesmix::Rng::Instance().get(); - // Select params to use - Hyperparams::FA params = use_post_hypers ? post_hypers : *hypers; + // Get params to use + auto params = hier_hypers.fa_state(); + Eigen::VectorXd mutilde = bayesmix::to_eigen(params.mutilde()); + Eigen::VectorXd beta = bayesmix::to_eigen(params.beta()); - // HO AGGIUNTO PARAMS.CARD MA NON SO SE SIA LA SCELTA MIGLIORE!!! // Compute output state State::FA out; - out.mu = params.mutilde; - out.psi = params.beta / (params.alpha0 + 1.); - // out.eta = Eigen::MatrixXd::Zero(params.card, params.q); - out.lambda = Eigen::MatrixXd::Zero(dim, params.q); + out.mu = mutilde; + out.psi = beta / (params.alpha0() + 1.); + out.lambda = Eigen::MatrixXd::Zero(dim, params.q()); for (size_t j = 0; j < dim; j++) { - out.mu[j] = - stan::math::normal_rng(params.mutilde[j], sqrt(params.phi), rng); + out.mu[j] = stan::math::normal_rng(mutilde[j], sqrt(params.phi()), rng); - out.psi[j] = stan::math::inv_gamma_rng(params.alpha0, params.beta[j], rng); + out.psi[j] = stan::math::inv_gamma_rng(params.alpha0(), beta[j], rng); - for (size_t i = 0; i < params.q; i++) { + for (size_t i = 0; i < params.q(); i++) { out.lambda(j, i) = stan::math::normal_rng(0, 1, rng); } } - // for (size_t i = 0; i < params.card; i++) { - // for (size_t j = 0; j < params.q; j++) { - // out.eta(i, j) = stan::math::normal_rng(0, 1, rng); - // } - // } - - // Questi conti non li passo al proto, attenzione !!! - // out.psi_inverse = out.psi.cwiseInverse().asDiagonal(); - // compute_wood_factors(out.cov_wood, out.cov_logdet, out.lambda, - // out.psi_inverse); // Eigen2Proto conversion bayesmix::AlgorithmState::ClusterState state; bayesmix::to_proto(out.mu, state.mutable_fa_state()->mutable_mu()); bayesmix::to_proto(out.psi, state.mutable_fa_state()->mutable_psi()); - // bayesmix::to_proto(out.eta, state.mutable_fa_state()->mutable_eta()); bayesmix::to_proto(out.lambda, state.mutable_fa_state()->mutable_lambda()); return std::make_shared(state); - - // MANCA PSI_INVERSE E GLI OUTPUT DA COMPUTE_WOOD_FACTORS !!! BISOGNA - // CAMBIARE IL PROTO } +// std::shared_ptr FAPriorModel::sample( +// bool use_post_hypers) { +// // Random seed +// auto &rng = bayesmix::Rng::Instance().get(); + +// // Select params to use +// Hyperparams::FA params = use_post_hypers ? post_hypers : *hypers; + +// // Compute output state +// State::FA out; +// out.mu = params.mutilde; +// out.psi = params.beta / (params.alpha0 + 1.); +// // out.eta = Eigen::MatrixXd::Zero(params.card, params.q); +// out.lambda = Eigen::MatrixXd::Zero(dim, params.q); +// for (size_t j = 0; j < dim; j++) { +// out.mu[j] = +// stan::math::normal_rng(params.mutilde[j], sqrt(params.phi), rng); + +// out.psi[j] = stan::math::inv_gamma_rng(params.alpha0, params.beta[j], +// rng); + +// for (size_t i = 0; i < params.q; i++) { +// out.lambda(j, i) = stan::math::normal_rng(0, 1, rng); +// } +// } +// // for (size_t i = 0; i < params.card; i++) { +// // for (size_t j = 0; j < params.q; j++) { +// // out.eta(i, j) = stan::math::normal_rng(0, 1, rng); +// // } +// // } + +// // Questi conti non li passo al proto, attenzione !!! +// // out.psi_inverse = out.psi.cwiseInverse().asDiagonal(); +// // compute_wood_factors(out.cov_wood, out.cov_logdet, out.lambda, +// // out.psi_inverse); + +// // Eigen2Proto conversion +// bayesmix::AlgorithmState::ClusterState state; +// bayesmix::to_proto(out.mu, state.mutable_fa_state()->mutable_mu()); +// bayesmix::to_proto(out.psi, state.mutable_fa_state()->mutable_psi()); +// // bayesmix::to_proto(out.eta, state.mutable_fa_state()->mutable_eta()); +// bayesmix::to_proto(out.lambda, +// state.mutable_fa_state()->mutable_lambda()); return +// std::make_shared(state); + +// // MANCA PSI_INVERSE E GLI OUTPUT DA COMPUTE_WOOD_FACTORS !!! BISOGNA +// // CAMBIARE IL PROTO +// } + void FAPriorModel::update_hypers( const std::vector &states) { auto &rng = bayesmix::Rng::Instance().get(); diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index cd4d6eea4..c8e7ab0fd 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -22,7 +22,10 @@ class FAPriorModel double lpdf(const google::protobuf::Message &state_) override; std::shared_ptr sample( - bool use_post_hypers) override; + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) override; + + // std::shared_ptr sample( + // bool use_post_hypers) override; void update_hypers(const std::vector &states) override; @@ -32,10 +35,10 @@ class FAPriorModel unsigned int get_dim() const { return dim; }; - protected: std::shared_ptr get_hypers_proto() const override; + protected: void initialize_hypers() override; unsigned int dim; diff --git a/src/hierarchies/priors/mnig_prior_model.cc b/src/hierarchies/priors/mnig_prior_model.cc index 34041a30a..e6d036674 100644 --- a/src/hierarchies/priors/mnig_prior_model.cc +++ b/src/hierarchies/priors/mnig_prior_model.cc @@ -12,25 +12,47 @@ double MNIGPriorModel::lpdf(const google::protobuf::Message &state_) { } std::shared_ptr MNIGPriorModel::sample( - bool use_post_hypers) { + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - Hyperparams::MNIG params = use_post_hypers ? post_hypers : *hypers; - double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - Eigen::VectorXd regression_coeffs = stan::math::multi_normal_prec_rng( - params.mean, params.var_scaling / var, rng); + auto params = hier_hypers.lin_reg_uni_state(); + Eigen::VectorXd mean = bayesmix::to_eigen(params.mean()); + Eigen::MatrixXd var_scaling = bayesmix::to_eigen(params.var_scaling()); + + double var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); + Eigen::VectorXd regression_coeffs = + stan::math::multi_normal_prec_rng(mean, var_scaling / var, rng); bayesmix::AlgorithmState::ClusterState state; - // bayesmix::Vector regression_coeffs_proto; bayesmix::to_proto( regression_coeffs, state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()); - // state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()->CopyFrom(regression_coeffs_proto); state.mutable_lin_reg_uni_ls_state()->set_var(var); return std::make_shared(state); } +// std::shared_ptr MNIGPriorModel::sample( +// bool use_post_hypers) { +// auto &rng = bayesmix::Rng::Instance().get(); +// Hyperparams::MNIG params = use_post_hypers ? post_hypers : *hypers; + +// double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); +// Eigen::VectorXd regression_coeffs = stan::math::multi_normal_prec_rng( +// params.mean, params.var_scaling / var, rng); + +// bayesmix::AlgorithmState::ClusterState state; +// // bayesmix::Vector regression_coeffs_proto; +// bayesmix::to_proto( +// regression_coeffs, +// state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()); +// // +// state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()->CopyFrom(regression_coeffs_proto); +// state.mutable_lin_reg_uni_ls_state()->set_var(var); + +// return std::make_shared(state); +// } + void MNIGPriorModel::update_hypers( const std::vector &states) { if (prior->has_fixed_values()) { diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h index 6502849d5..b307675a0 100644 --- a/src/hierarchies/priors/mnig_prior_model.h +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -22,7 +22,9 @@ class MNIGPriorModel : public BasePriorModel sample( - bool use_post_hypers) override; + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) override; + // std::shared_ptr sample( + // bool use_post_hypers) override; void update_hypers(const std::vector &states) override; @@ -32,10 +34,10 @@ class MNIGPriorModel : public BasePriorModel get_hypers_proto() const override; + protected: void initialize_hypers() override; unsigned int dim; diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index 67009462c..5cb797c08 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -84,19 +84,36 @@ double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { return target; } +// std::shared_ptr NIGPriorModel::sample( +// bool use_post_hypers) { +// auto &rng = bayesmix::Rng::Instance().get(); +// Hyperparams::NIG params = use_post_hypers ? post_hypers : *hypers; +// double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); +// double mean = +// stan::math::normal_rng(params.mean, sqrt(var / params.var_scaling), +// rng); + +// bayesmix::AlgorithmState::ClusterState state; +// state.mutable_uni_ls_state()->set_mean(mean); +// state.mutable_uni_ls_state()->set_var(var); +// return std::make_shared(state); +// }; + std::shared_ptr NIGPriorModel::sample( - bool use_post_hypers) { + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - Hyperparams::NIG params = use_post_hypers ? post_hypers : *hypers; - double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - double mean = - stan::math::normal_rng(params.mean, sqrt(var / params.var_scaling), rng); + auto params = hier_hypers.nnig_state(); + + // Hyperparams::NIG params = use_post_hypers ? post_hypers : *hypers; + double var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); + double mean = stan::math::normal_rng(params.mean(), + sqrt(var / params.var_scaling()), rng); bayesmix::AlgorithmState::ClusterState state; state.mutable_uni_ls_state()->set_mean(mean); state.mutable_uni_ls_state()->set_var(var); return std::make_shared(state); -}; +} void NIGPriorModel::update_hypers( const std::vector &states) { diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index 0cda78004..e791eaa11 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -37,8 +37,11 @@ class NIGPriorModel : public BasePriorModel sample( + // bool use_post_hypers) override; + std::shared_ptr sample( - bool use_post_hypers) override; + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) override; void update_hypers(const std::vector &states) override; @@ -46,10 +49,10 @@ class NIGPriorModel : public BasePriorModel get_hypers_proto() const override; + protected: void initialize_hypers() override; }; diff --git a/src/hierarchies/priors/nw_prior_model.cc b/src/hierarchies/priors/nw_prior_model.cc index f0bb2dfe7..ebdf73910 100644 --- a/src/hierarchies/priors/nw_prior_model.cc +++ b/src/hierarchies/priors/nw_prior_model.cc @@ -122,18 +122,19 @@ double NWPriorModel::lpdf(const google::protobuf::Message &state_) { } std::shared_ptr NWPriorModel::sample( - bool use_post_hypers) { + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - - Hyperparams::NW params = use_post_hypers ? post_hypers : *hypers; + auto params = hier_hypers.nnw_state(); + Eigen::MatrixXd scale = bayesmix::to_eigen(params.scale()); + Eigen::VectorXd mean = bayesmix::to_eigen(params.mean()); Eigen::MatrixXd tau_new = - stan::math::wishart_rng(params.deg_free, params.scale, rng); + stan::math::wishart_rng(params.deg_free(), scale, rng); // Update state State::MultiLS out; out.mean = stan::math::multi_normal_prec_rng( - params.mean, tau_new * params.var_scaling, rng); + mean, tau_new * params.var_scaling(), rng); write_prec_to_state(tau_new, &out); // Make output state @@ -145,6 +146,32 @@ std::shared_ptr NWPriorModel::sample( return std::make_shared(state); }; +// std::shared_ptr NWPriorModel::sample( +// bool use_post_hypers) { +// auto &rng = bayesmix::Rng::Instance().get(); + +// Hyperparams::NW params = use_post_hypers ? post_hypers : *hypers; + +// Eigen::MatrixXd tau_new = +// stan::math::wishart_rng(params.deg_free, params.scale, rng); + +// // Update state +// State::MultiLS out; +// out.mean = stan::math::multi_normal_prec_rng( +// params.mean, tau_new * params.var_scaling, rng); +// write_prec_to_state(tau_new, &out); + +// // Make output state +// bayesmix::AlgorithmState::ClusterState state; +// bayesmix::to_proto(out.mean, +// state.mutable_multi_ls_state()->mutable_mean()); +// bayesmix::to_proto(out.prec, +// state.mutable_multi_ls_state()->mutable_prec()); +// bayesmix::to_proto(out.prec_chol, +// state.mutable_multi_ls_state()->mutable_prec_chol()); +// return std::make_shared(state); +// }; + void NWPriorModel::update_hypers( const std::vector &states) { auto &rng = bayesmix::Rng::Instance().get(); diff --git a/src/hierarchies/priors/nw_prior_model.h b/src/hierarchies/priors/nw_prior_model.h index 9fd349bc1..cb393ee7f 100644 --- a/src/hierarchies/priors/nw_prior_model.h +++ b/src/hierarchies/priors/nw_prior_model.h @@ -24,7 +24,10 @@ class NWPriorModel : public BasePriorModel sample( - bool use_post_hypers) override; + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) override; + + // std::shared_ptr sample( + // bool use_post_hypers) override; void update_hypers(const std::vector &states) override; @@ -36,10 +39,10 @@ class NWPriorModel : public BasePriorModel get_hypers_proto() const override; + protected: void initialize_hypers() override; unsigned int dim; diff --git a/src/hierarchies/priors/nxig_prior_model.cc b/src/hierarchies/priors/nxig_prior_model.cc index eec6f3871..5c7a953aa 100644 --- a/src/hierarchies/priors/nxig_prior_model.cc +++ b/src/hierarchies/priors/nxig_prior_model.cc @@ -31,12 +31,12 @@ double NxIGPriorModel::lpdf(const google::protobuf::Message &state_) { } std::shared_ptr NxIGPriorModel::sample( - bool use_post_hypers) { + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - Hyperparams::NxIG params = use_post_hypers ? post_hypers : *hypers; + auto params = hier_hypers.nnxig_state(); - double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); - double mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng); + double var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); + double mean = stan::math::normal_rng(params.mean(), sqrt(params.var()), rng); bayesmix::AlgorithmState::ClusterState state; state.mutable_uni_ls_state()->set_mean(mean); @@ -44,6 +44,20 @@ std::shared_ptr NxIGPriorModel::sample( return std::make_shared(state); }; +// std::shared_ptr NxIGPriorModel::sample( +// bool use_post_hypers) { +// auto &rng = bayesmix::Rng::Instance().get(); +// Hyperparams::NxIG params = use_post_hypers ? post_hypers : *hypers; + +// double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); +// double mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng); + +// bayesmix::AlgorithmState::ClusterState state; +// state.mutable_uni_ls_state()->set_mean(mean); +// state.mutable_uni_ls_state()->set_var(var); +// return std::make_shared(state); +// }; + void NxIGPriorModel::update_hypers( const std::vector &states) { if (prior->has_fixed_values()) { diff --git a/src/hierarchies/priors/nxig_prior_model.h b/src/hierarchies/priors/nxig_prior_model.h index 85ade8612..bb5f33052 100644 --- a/src/hierarchies/priors/nxig_prior_model.h +++ b/src/hierarchies/priors/nxig_prior_model.h @@ -23,8 +23,11 @@ class NxIGPriorModel : public BasePriorModel sample( + // bool use_post_hypers) override; + std::shared_ptr sample( - bool use_post_hypers) override; + bayesmix::AlgorithmState::HierarchyHypers hier_hypers) override; void update_hypers(const std::vector &states) override; @@ -32,10 +35,10 @@ class NxIGPriorModel : public BasePriorModel get_hypers_proto() const override; + protected: void initialize_hypers() override; }; diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 4c412397e..0983b0146 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -7,6 +7,9 @@ class AbstractUpdater { public: + // Type alias + using ProtoHypers = bayesmix::AlgorithmState::HierarchyHypers; + //! Default destructor virtual ~AbstractUpdater() = default; @@ -23,10 +26,16 @@ class AbstractUpdater { //! Computes the posterior hyperparameters required for the sampling in case //! of conjugate hierarchies - virtual void compute_posterior_hypers(AbstractLikelihood &like, - AbstractPriorModel &prior) { - throw std::runtime_error( - "compute_posterior_hypers() not implemented for this updater"); + virtual ProtoHypers compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) { + throw(std::runtime_error( + "compute_posterior_hypers() not implemented for this updater")); + } + + //! Stores the posterior hyperparameters in an appropriate container + virtual void save_posterior_hypers(const ProtoHypers &post_hypers_) { + throw(std::runtime_error( + "save_posterior_hypers() not implemented for this updater")); } }; diff --git a/src/hierarchies/updaters/fa_updater.cc b/src/hierarchies/updaters/fa_updater.cc index 3f8f81694..8fb9f22b5 100644 --- a/src/hierarchies/updaters/fa_updater.cc +++ b/src/hierarchies/updaters/fa_updater.cc @@ -10,8 +10,8 @@ void FAUpdater::draw(AbstractLikelihood& like, AbstractPriorModel& prior, // Sample from the full conditional of the fa hierarchy bool set_card = true, use_post_hypers = true; if (likecast.get_card() == 0) { - likecast.set_state_from_proto(*priorcast.sample(!use_post_hypers), - !set_card); + auto prior_params = *(priorcast.get_hypers_proto()); + likecast.set_state_from_proto(*priorcast.sample(prior_params), !set_card); } else { // Get state and hypers State::FA new_state = likecast.get_state(); @@ -19,10 +19,9 @@ void FAUpdater::draw(AbstractLikelihood& like, AbstractPriorModel& prior, // Gibbs update sample_eta(new_state, hypers, likecast); sample_mu(new_state, hypers, likecast); - // sample_psi(new_state, hypers, likecast.get_dataset(), - // likecast.get_data_idx(), priorcast.get_dim()); sample_lambda(new_state, - // hypers, likecast.get_dataset(), likecast.get_data_idx(), - // priorcast.get_dim()); Eigen2Proto conversion + sample_psi(new_state, hypers, likecast); + sample_lambda(new_state, hypers, likecast); + // Eigen2Proto conversion bayesmix::AlgorithmState::ClusterState new_state_proto; bayesmix::to_proto(new_state.eta, new_state_proto.mutable_fa_state()->mutable_eta()); @@ -45,13 +44,11 @@ void FAUpdater::sample_eta(State::FA& state, const Hyperparams::FA& hypers, auto cluster_data_idx = like.get_data_idx(); unsigned int card = like.get_card(); // eta update + state.eta = Eigen::MatrixXd::Zero(card, hypers.q); auto sigma_eta_inv_llt = (Eigen::MatrixXd::Identity(hypers.q, hypers.q) + state.lambda.transpose() * state.psi_inverse * state.lambda) .llt(); - if (state.eta.rows() != card) { - state.eta = Eigen::MatrixXd::Zero(card, state.eta.cols()); - } Eigen::MatrixXd temp_product( sigma_eta_inv_llt.solve(state.lambda.transpose() * state.psi_inverse)); auto iterator = cluster_data_idx.begin(); diff --git a/src/hierarchies/updaters/mnig_updater.cc b/src/hierarchies/updaters/mnig_updater.cc index 2d1280094..23d7a8627 100644 --- a/src/hierarchies/updaters/mnig_updater.cc +++ b/src/hierarchies/updaters/mnig_updater.cc @@ -1,7 +1,7 @@ #include "mnig_updater.h" -void MNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, - AbstractPriorModel& prior) { +AbstractUpdater::ProtoHypers MNIGUpdater::compute_posterior_hypers( + AbstractLikelihood& like, AbstractPriorModel& prior) { // Likelihood and Prior downcast auto& likecast = downcast_likelihood(like); auto& priorcast = downcast_prior(prior); @@ -16,24 +16,68 @@ void MNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, // No update possible if (card == 0) { - priorcast.set_posterior_hypers(hypers); - return; + return *(priorcast.get_hypers_proto()); } // Compute posterior hyperparameters - Hyperparams::MNIG post_params; - post_params.var_scaling = covar_sum_squares + hypers.var_scaling; - auto llt = post_params.var_scaling.llt(); - post_params.var_scaling_inv = llt.solve(Eigen::MatrixXd::Identity(dim, dim)); - post_params.mean = llt.solve(mixed_prod + hypers.var_scaling * hypers.mean); - post_params.shape = hypers.shape + 0.5 * card; - post_params.scale = - hypers.scale + - 0.5 * (data_sum_squares + - hypers.mean.transpose() * hypers.var_scaling * hypers.mean - - post_params.mean.transpose() * post_params.var_scaling * - post_params.mean); - - priorcast.set_posterior_hypers(post_params); - return; -}; + Eigen::VectorXd mean; + Eigen::MatrixXd var_scaling, var_scaling_inv; + double shape, scale; + + var_scaling = covar_sum_squares + hypers.var_scaling; + auto llt = var_scaling.llt(); + // var_scaling_inv = llt.solve(Eigen::MatrixXd::Identity(dim, dim)); + mean = llt.solve(mixed_prod + hypers.var_scaling * hypers.mean); + shape = hypers.shape + 0.5 * card; + scale = hypers.scale + + 0.5 * (data_sum_squares + + hypers.mean.transpose() * hypers.var_scaling * hypers.mean - + mean.transpose() * var_scaling * mean); + + // Proto conversion + ProtoHypers out; + bayesmix::to_proto(mean, out.mutable_lin_reg_uni_state()->mutable_mean()); + bayesmix::to_proto(var_scaling, + out.mutable_lin_reg_uni_state()->mutable_var_scaling()); + out.mutable_lin_reg_uni_state()->set_shape(shape); + out.mutable_lin_reg_uni_state()->set_scale(scale); + return out; +} + +// void MNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, +// AbstractPriorModel& prior) { +// // Likelihood and Prior downcast +// auto& likecast = downcast_likelihood(like); +// auto& priorcast = downcast_prior(prior); + +// // Getting required quantities from likelihood and prior +// int card = likecast.get_card(); +// unsigned int dim = likecast.get_dim(); +// double data_sum_squares = likecast.get_data_sum_squares(); +// Eigen::MatrixXd covar_sum_squares = likecast.get_covar_sum_squares(); +// Eigen::MatrixXd mixed_prod = likecast.get_mixed_prod(); +// auto hypers = priorcast.get_hypers(); + +// // No update possible +// if (card == 0) { +// priorcast.set_posterior_hypers(hypers); +// return; +// } + +// // Compute posterior hyperparameters +// Hyperparams::MNIG post_params; +// post_params.var_scaling = covar_sum_squares + hypers.var_scaling; +// auto llt = post_params.var_scaling.llt(); +// post_params.var_scaling_inv = llt.solve(Eigen::MatrixXd::Identity(dim, +// dim)); post_params.mean = llt.solve(mixed_prod + hypers.var_scaling * +// hypers.mean); post_params.shape = hypers.shape + 0.5 * card; +// post_params.scale = +// hypers.scale + +// 0.5 * (data_sum_squares + +// hypers.mean.transpose() * hypers.var_scaling * hypers.mean - +// post_params.mean.transpose() * post_params.var_scaling * +// post_params.mean); + +// priorcast.set_posterior_hypers(post_params); +// return; +// }; diff --git a/src/hierarchies/updaters/mnig_updater.h b/src/hierarchies/updaters/mnig_updater.h index 6a58174bd..fd1c17214 100644 --- a/src/hierarchies/updaters/mnig_updater.h +++ b/src/hierarchies/updaters/mnig_updater.h @@ -1,18 +1,23 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_MNIG_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_MNIG_UPDATER_H_ -#include "conjugate_updater.h" +#include "semi_conjugate_updater.h" #include "src/hierarchies/likelihoods/uni_lin_reg_likelihood.h" #include "src/hierarchies/priors/mnig_prior_model.h" class MNIGUpdater - : public ConjugateUpdater { + : public SemiConjugateUpdater { public: MNIGUpdater() = default; ~MNIGUpdater() = default; - void compute_posterior_hypers(AbstractLikelihood& like, - AbstractPriorModel& prior) override; + bool is_conjugate() const override { return true; }; + + ProtoHypers compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; + + // void compute_posterior_hypers(AbstractLikelihood& like, + // AbstractPriorModel& prior) override; }; #endif // BAYESMIX_HIERARCHIES_UPDATERS_MNIG_UPDATER_H_ diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc index 9bd709877..f06c1ba5d 100644 --- a/src/hierarchies/updaters/nnig_updater.cc +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -3,8 +3,8 @@ #include "src/hierarchies/likelihoods/states/includes.h" #include "src/hierarchies/priors/hyperparams.h" -void NNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, - AbstractPriorModel& prior) { +AbstractUpdater::ProtoHypers NNIGUpdater::compute_posterior_hypers( + AbstractLikelihood& like, AbstractPriorModel& prior) { // Likelihood and Prior downcast auto& likecast = downcast_likelihood(like); auto& priorcast = downcast_prior(prior); @@ -17,22 +17,64 @@ void NNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, // No update possible if (card == 0) { - priorcast.set_posterior_hypers(hypers); - return; + return *(priorcast.get_hypers_proto()); } // Compute posterior hyperparameters - Hyperparams::NIG post_params; + double mean, var_scaling, shape, scale; + // Hyperparams::NIG post_params; double y_bar = data_sum / (1.0 * card); // sample mean double ss = data_sum_squares - card * y_bar * y_bar; - post_params.mean = (hypers.var_scaling * hypers.mean + data_sum) / - (hypers.var_scaling + card); - post_params.var_scaling = hypers.var_scaling + card; - post_params.shape = hypers.shape + 0.5 * card; - post_params.scale = hypers.scale + 0.5 * ss + - 0.5 * hypers.var_scaling * card * (y_bar - hypers.mean) * - (y_bar - hypers.mean) / (card + hypers.var_scaling); - - priorcast.set_posterior_hypers(post_params); - return; -}; + mean = (hypers.var_scaling * hypers.mean + data_sum) / + (hypers.var_scaling + card); + var_scaling = hypers.var_scaling + card; + shape = hypers.shape + 0.5 * card; + scale = hypers.scale + 0.5 * ss + + 0.5 * hypers.var_scaling * card * (y_bar - hypers.mean) * + (y_bar - hypers.mean) / (card + hypers.var_scaling); + + // Proto conversion + ProtoHypers out; + out.mutable_nnig_state()->set_mean(mean); + out.mutable_nnig_state()->set_var_scaling(var_scaling); + out.mutable_nnig_state()->set_shape(shape); + out.mutable_nnig_state()->set_scale(scale); + // priorcast.set_posterior_hypers(post_params); + return out; +} + +// void NNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, +// AbstractPriorModel& prior) { +// // Likelihood and Prior downcast +// auto& likecast = downcast_likelihood(like); +// auto& priorcast = downcast_prior(prior); + +// // Getting required quantities from likelihood and prior +// int card = likecast.get_card(); +// double data_sum = likecast.get_data_sum(); +// double data_sum_squares = likecast.get_data_sum_squares(); +// auto hypers = priorcast.get_hypers(); + +// // No update possible +// if (card == 0) { +// priorcast.set_posterior_hypers(hypers); +// return; +// } + +// // Compute posterior hyperparameters +// Hyperparams::NIG post_params; +// double y_bar = data_sum / (1.0 * card); // sample mean +// double ss = data_sum_squares - card * y_bar * y_bar; +// post_params.mean = (hypers.var_scaling * hypers.mean + data_sum) / +// (hypers.var_scaling + card); +// post_params.var_scaling = hypers.var_scaling + card; +// post_params.shape = hypers.shape + 0.5 * card; +// post_params.scale = hypers.scale + 0.5 * ss + +// 0.5 * hypers.var_scaling * card * (y_bar - +// hypers.mean) * +// (y_bar - hypers.mean) / (card + +// hypers.var_scaling); + +// priorcast.set_posterior_hypers(post_params); +// return; +// }; diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h index 02d6bdce8..cc9e18bfe 100644 --- a/src/hierarchies/updaters/nnig_updater.h +++ b/src/hierarchies/updaters/nnig_updater.h @@ -1,17 +1,23 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ -#include "conjugate_updater.h" +#include "semi_conjugate_updater.h" #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" #include "src/hierarchies/priors/nig_prior_model.h" -class NNIGUpdater : public ConjugateUpdater { +class NNIGUpdater + : public SemiConjugateUpdater { public: NNIGUpdater() = default; ~NNIGUpdater() = default; - void compute_posterior_hypers(AbstractLikelihood& like, - AbstractPriorModel& prior) override; + bool is_conjugate() const override { return true; }; + + ProtoHypers compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; + + // void compute_posterior_hypers(AbstractLikelihood& like, + // AbstractPriorModel& prior) override; }; #endif // BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ diff --git a/src/hierarchies/updaters/nnw_updater.cc b/src/hierarchies/updaters/nnw_updater.cc index ad43fe6fc..bf8b77d62 100644 --- a/src/hierarchies/updaters/nnw_updater.cc +++ b/src/hierarchies/updaters/nnw_updater.cc @@ -1,10 +1,12 @@ #include "nnw_updater.h" +#include "algorithm_state.pb.h" #include "src/hierarchies/likelihoods/states/includes.h" #include "src/hierarchies/priors/hyperparams.h" +#include "src/utils/proto_utils.h" -void NNWUpdater::compute_posterior_hypers(AbstractLikelihood& like, - AbstractPriorModel& prior) { +AbstractUpdater::ProtoHypers NNWUpdater::compute_posterior_hypers( + AbstractLikelihood& like, AbstractPriorModel& prior) { // Likelihood and Prior downcast auto& likecast = downcast_likelihood(like); auto& priorcast = downcast_prior(prior); @@ -17,26 +19,70 @@ void NNWUpdater::compute_posterior_hypers(AbstractLikelihood& like, // No update possible if (card == 0) { - priorcast.set_posterior_hypers(hypers); - return; + return *(prior.get_hypers_proto()); } // Compute posterior hyperparameters - Hyperparams::NW post_params; - post_params.var_scaling = hypers.var_scaling + card; - post_params.deg_free = hypers.deg_free + card; + Eigen::VectorXd mean; + double var_scaling, deg_free; + Eigen::MatrixXd scale, scale_inv, scale_chol; + var_scaling = hypers.var_scaling + card; + deg_free = hypers.deg_free + card; Eigen::VectorXd mubar = data_sum.array() / card; // sample mean - post_params.mean = (hypers.var_scaling * hypers.mean + card * mubar) / - (hypers.var_scaling + card); + mean = (hypers.var_scaling * hypers.mean + card * mubar) / + (hypers.var_scaling + card); // Compute tau_n Eigen::MatrixXd tau_temp = data_sum_squares - card * mubar * mubar.transpose(); tau_temp += (card * hypers.var_scaling / (card + hypers.var_scaling)) * (mubar - hypers.mean) * (mubar - hypers.mean).transpose(); - post_params.scale_inv = tau_temp + hypers.scale_inv; - post_params.scale = stan::math::inverse_spd(post_params.scale_inv); - post_params.scale_chol = - Eigen::LLT(post_params.scale).matrixU(); - priorcast.set_posterior_hypers(post_params); - return; -}; + scale_inv = tau_temp + hypers.scale_inv; + scale = stan::math::inverse_spd(scale_inv); + // scale_chol = Eigen::LLT(scale).matrixU(); + + // Proto conversion + bayesmix::AlgorithmState::HierarchyHypers out; + bayesmix::to_proto(mean, out.mutable_nnw_state()->mutable_mean()); + out.mutable_nnw_state()->set_var_scaling(var_scaling); + out.mutable_nnw_state()->set_deg_free(deg_free); + bayesmix::to_proto(scale, out.mutable_nnw_state()->mutable_scale()); + return out; +} + +// void NNWUpdater::compute_posterior_hypers(AbstractLikelihood& like, +// AbstractPriorModel& prior) { +// // Likelihood and Prior downcast +// auto& likecast = downcast_likelihood(like); +// auto& priorcast = downcast_prior(prior); + +// // Getting required quantities from likelihood and prior +// int card = likecast.get_card(); +// Eigen::VectorXd data_sum = likecast.get_data_sum(); +// Eigen::MatrixXd data_sum_squares = likecast.get_data_sum_squares(); +// auto hypers = priorcast.get_hypers(); + +// // No update possible +// if (card == 0) { +// priorcast.set_posterior_hypers(hypers); +// return; +// } + +// // Compute posterior hyperparameters +// Hyperparams::NW post_params; +// post_params.var_scaling = hypers.var_scaling + card; +// post_params.deg_free = hypers.deg_free + card; +// Eigen::VectorXd mubar = data_sum.array() / card; // sample mean +// post_params.mean = (hypers.var_scaling * hypers.mean + card * mubar) / +// (hypers.var_scaling + card); +// // Compute tau_n +// Eigen::MatrixXd tau_temp = +// data_sum_squares - card * mubar * mubar.transpose(); +// tau_temp += (card * hypers.var_scaling / (card + hypers.var_scaling)) * +// (mubar - hypers.mean) * (mubar - hypers.mean).transpose(); +// post_params.scale_inv = tau_temp + hypers.scale_inv; +// post_params.scale = stan::math::inverse_spd(post_params.scale_inv); +// post_params.scale_chol = +// Eigen::LLT(post_params.scale).matrixU(); +// priorcast.set_posterior_hypers(post_params); +// return; +// }; diff --git a/src/hierarchies/updaters/nnw_updater.h b/src/hierarchies/updaters/nnw_updater.h index 18f6a824f..357022e23 100644 --- a/src/hierarchies/updaters/nnw_updater.h +++ b/src/hierarchies/updaters/nnw_updater.h @@ -1,17 +1,23 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_NNW_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_NNW_UPDATER_H_ -#include "conjugate_updater.h" +#include "semi_conjugate_updater.h" #include "src/hierarchies/likelihoods/multi_norm_likelihood.h" #include "src/hierarchies/priors/nw_prior_model.h" -class NNWUpdater : public ConjugateUpdater { +class NNWUpdater + : public SemiConjugateUpdater { public: NNWUpdater() = default; ~NNWUpdater() = default; - void compute_posterior_hypers(AbstractLikelihood& like, - AbstractPriorModel& prior) override; + bool is_conjugate() const override { return true; }; + + ProtoHypers compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; + + // void compute_posterior_hypers(AbstractLikelihood& like, + // AbstractPriorModel& prior) override; }; #endif // BAYESMIX_HIERARCHIES_UPDATERS_NNW_UPDATER_H_ diff --git a/src/hierarchies/updaters/nnxig_updater.cc b/src/hierarchies/updaters/nnxig_updater.cc index c94c8ba3a..a4fa33ab7 100644 --- a/src/hierarchies/updaters/nnxig_updater.cc +++ b/src/hierarchies/updaters/nnxig_updater.cc @@ -3,8 +3,8 @@ #include "src/hierarchies/likelihoods/states/includes.h" #include "src/hierarchies/priors/hyperparams.h" -void NNxIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, - AbstractPriorModel& prior) { +AbstractUpdater::ProtoHypers NNxIGUpdater::compute_posterior_hypers( + AbstractLikelihood& like, AbstractPriorModel& prior) { // Likelihood and Prior downcast auto& likecast = downcast_likelihood(like); auto& priorcast = downcast_prior(prior); @@ -18,18 +18,55 @@ void NNxIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, // No update possible if (card == 0) { - priorcast.set_posterior_hypers(hypers); + return *(priorcast.get_hypers_proto()); } // Compute posterior hyperparameters - Hyperparams::NxIG post_params; + double mean, var, shape, scale; double var_y = data_sum_squares - 2 * state.mean * data_sum + card * state.mean * state.mean; - post_params.mean = (hypers.var * data_sum + state.var * hypers.mean) / - (card * hypers.var + state.var); - post_params.var = (state.var * hypers.var) / (card * hypers.var + state.var); - post_params.shape = hypers.shape + 0.5 * card; - post_params.scale = hypers.scale + 0.5 * var_y; - priorcast.set_posterior_hypers(post_params); - return; -}; + mean = (hypers.var * data_sum + state.var * hypers.mean) / + (card * hypers.var + state.var); + var = (state.var * hypers.var) / (card * hypers.var + state.var); + shape = hypers.shape + 0.5 * card; + scale = hypers.scale + 0.5 * var_y; + + // Proto conversion + ProtoHypers out; + out.mutable_nnxig_state()->set_mean(mean); + out.mutable_nnxig_state()->set_var(var); + out.mutable_nnxig_state()->set_shape(shape); + out.mutable_nnxig_state()->set_scale(scale); + return out; +} + +// void NNxIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, +// AbstractPriorModel& prior) { +// // Likelihood and Prior downcast +// auto& likecast = downcast_likelihood(like); +// auto& priorcast = downcast_prior(prior); + +// // Getting required quantities from likelihood and prior +// auto state = likecast.get_state(); +// int card = likecast.get_card(); +// double data_sum = likecast.get_data_sum(); +// double data_sum_squares = likecast.get_data_sum_squares(); +// auto hypers = priorcast.get_hypers(); + +// // No update possible +// if (card == 0) { +// priorcast.set_posterior_hypers(hypers); +// } + +// // Compute posterior hyperparameters +// Hyperparams::NxIG post_params; +// double var_y = data_sum_squares - 2 * state.mean * data_sum + +// card * state.mean * state.mean; +// post_params.mean = (hypers.var * data_sum + state.var * hypers.mean) / +// (card * hypers.var + state.var); +// post_params.var = (state.var * hypers.var) / (card * hypers.var + +// state.var); post_params.shape = hypers.shape + 0.5 * card; +// post_params.scale = hypers.scale + 0.5 * var_y; +// priorcast.set_posterior_hypers(post_params); +// return; +// }; diff --git a/src/hierarchies/updaters/nnxig_updater.h b/src/hierarchies/updaters/nnxig_updater.h index 15e680db8..3200ee2df 100644 --- a/src/hierarchies/updaters/nnxig_updater.h +++ b/src/hierarchies/updaters/nnxig_updater.h @@ -1,18 +1,21 @@ #ifndef BAYESMIX_HIERARCHIES_UPDATERS_NNXIG_UPDATER_H_ #define BAYESMIX_HIERARCHIES_UPDATERS_NNXIG_UPDATER_H_ -#include "conjugate_updater.h" +#include "semi_conjugate_updater.h" #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" #include "src/hierarchies/priors/nxig_prior_model.h" -class NNxIGUpdater : public ConjugateUpdater { +class NNxIGUpdater + : public SemiConjugateUpdater { public: NNxIGUpdater() = default; ~NNxIGUpdater() = default; - - bool is_conjugate() const override { return false; }; - void compute_posterior_hypers(AbstractLikelihood& like, - AbstractPriorModel& prior) override; + + ProtoHypers compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; + + // void compute_posterior_hypers(AbstractLikelihood& like, + // AbstractPriorModel& prior) override; }; #endif // BAYESMIX_HIERARCHIES_UPDATERS_NNXIG_UPDATER_H_ diff --git a/src/hierarchies/updaters/semi_conjugate_updater.h b/src/hierarchies/updaters/semi_conjugate_updater.h new file mode 100644 index 000000000..2ced96cf9 --- /dev/null +++ b/src/hierarchies/updaters/semi_conjugate_updater.h @@ -0,0 +1,66 @@ +#ifndef BAYESMIX_HIERARCHIES_UPDATERS_SEMI_CONJUGATE_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_UPDATERS_SEMI_CONJUGATE_UPDATER_H_ + +#include + +#include "abstract_updater.h" +#include "src/hierarchies/likelihoods/abstract_likelihood.h" +#include "src/hierarchies/priors/abstract_prior_model.h" + +template +class SemiConjugateUpdater : public AbstractUpdater { + public: + SemiConjugateUpdater() = default; + + ~SemiConjugateUpdater() = default; + + void draw(AbstractLikelihood& like, AbstractPriorModel& prior, + bool update_params) override; + + void save_posterior_hypers(const ProtoHypers& post_hypers_) override; + + protected: + Likelihood& downcast_likelihood(AbstractLikelihood& like_); + PriorModel& downcast_prior(AbstractPriorModel& prior_); + ProtoHypers post_hypers; +}; + +// Methods' definitions +template +Likelihood& SemiConjugateUpdater::downcast_likelihood( + AbstractLikelihood& like_) { + return static_cast(like_); +} + +template +PriorModel& SemiConjugateUpdater::downcast_prior( + AbstractPriorModel& prior_) { + return static_cast(prior_); +} + +template +void SemiConjugateUpdater::draw( + AbstractLikelihood& like, AbstractPriorModel& prior, bool update_params) { + // Likelihood and PriorModel downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + // Sample from the full conditional of a semi-conjugate hierarchy + bool set_card = true; /*, use_post_hypers=true;*/ + if (likecast.get_card() == 0) { + auto prior_params = *priorcast.get_hypers_proto(); + likecast.set_state_from_proto(*priorcast.sample(prior_params), !set_card); + } else { + auto post_params = compute_posterior_hypers(likecast, priorcast); + likecast.set_state_from_proto(*priorcast.sample(post_params), !set_card); + if (update_params) save_posterior_hypers(post_params); + } +} + +template +void SemiConjugateUpdater::save_posterior_hypers( + const ProtoHypers& post_hypers_) { + post_hypers = post_hypers_; + return; +} + +#endif // BAYESMIX_HIERARCHIES_UPDATERS_SEMI_CONJUGATE_UPDATER_H_ diff --git a/test/likelihoods.cc b/test/likelihoods.cc index bd52f7b26..d0fccc775 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -371,8 +371,8 @@ TEST(laplace_likelihood, eval_lpdf_unconstrained) { clust_state_.mutable_uni_ls_state()->CopyFrom(state_); like->set_state_from_proto(clust_state_); - // Add new datum to likelihood - Eigen::VectorXd data(3); + // Add new data to likelihood + Eigen::MatrixXd data(3, 1); data << 4.5, 5.1, 2.5; double lpdf = 0.0; for (int i = 0; i < data.size(); ++i) { @@ -380,6 +380,7 @@ TEST(laplace_likelihood, eval_lpdf_unconstrained) { lpdf += like->lpdf(data.row(i)); } + like->set_dataset(&data); // Questa cosa è sempre garantita?? double clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); ASSERT_DOUBLE_EQ(lpdf, clus_lpdf); diff --git a/test/prior_models.cc b/test/prior_models.cc index 2374ae73d..7a4b170cf 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -111,7 +111,6 @@ TEST(nig_prior_model, normal_mean_prior) { TEST(nig_prior_model, sample) { // Instance auto prior = std::make_shared(); - bool use_post_hypers = true; // Define prior hypers bayesmix::AlgorithmState::HierarchyHypers hypers_proto; @@ -122,8 +121,8 @@ TEST(nig_prior_model, sample) { // Set hypers and get sampled state as proto prior->set_hypers_from_proto(hypers_proto); - auto state1 = prior->sample(!use_post_hypers); - auto state2 = prior->sample(!use_post_hypers); + auto state1 = prior->sample(*prior->get_hypers_proto()); + auto state2 = prior->sample(*prior->get_hypers_proto()); // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); @@ -193,7 +192,7 @@ TEST(nxig_prior_model, fixed_values_prior) { TEST(nxig_prior_model, sample) { // Instance auto prior = std::make_shared(); - bool use_post_hypers = true; + // bool use_post_hypers = true; // Define prior hypers bayesmix::AlgorithmState::HierarchyHypers hypers_proto; @@ -204,8 +203,8 @@ TEST(nxig_prior_model, sample) { // Set hypers and get sampled state as proto prior->set_hypers_from_proto(hypers_proto); - auto state1 = prior->sample(!use_post_hypers); - auto state2 = prior->sample(!use_post_hypers); + auto state1 = prior->sample(*prior->get_hypers_proto()); + auto state2 = prior->sample(*prior->get_hypers_proto()); // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); @@ -364,7 +363,7 @@ TEST(nw_prior_model, normal_mean_prior) { TEST(nw_prior_model, sample) { // Instance auto prior = std::make_shared(); - bool use_post_hypers = true; + // bool use_post_hypers = true; // Define prior hypers bayesmix::AlgorithmState::HierarchyHypers hypers_proto; @@ -379,8 +378,8 @@ TEST(nw_prior_model, sample) { // Set hypers and get sampled state as proto prior->set_hypers_from_proto(hypers_proto); - auto state1 = prior->sample(!use_post_hypers); - auto state2 = prior->sample(!use_post_hypers); + auto state1 = prior->sample(*prior->get_hypers_proto()); + auto state2 = prior->sample(*prior->get_hypers_proto()); // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); @@ -455,7 +454,7 @@ TEST(mnig_prior_model, fixed_values_prior) { TEST(mnig_prior_model, sample) { // Instance auto prior = std::make_shared(); - bool use_post_hypers = true; + // bool use_post_hypers = true; // Define prior hypers bayesmix::AlgorithmState::HierarchyHypers hypers_proto; @@ -471,8 +470,8 @@ TEST(mnig_prior_model, sample) { // Set hypers and get sampled state as proto prior->set_hypers_from_proto(hypers_proto); - auto state1 = prior->sample(!use_post_hypers); - auto state2 = prior->sample(!use_post_hypers); + auto state1 = prior->sample(*prior->get_hypers_proto()); + auto state2 = prior->sample(*prior->get_hypers_proto()); // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); From 53426186499189e05f3d0d3c3654cd73d210a234 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 14:37:36 +0200 Subject: [PATCH 226/317] ignore .old folders --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 4766a42ce..34b7e5ed9 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,6 @@ resources/2d .idea/ # Build debug folder cmake-build-debug/ + +# .old folders +src/hierarchies/updaters/.old/ From d2b193477e6bb2338ac0fd7e746b1361bc4a8814 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 14:38:37 +0200 Subject: [PATCH 227/317] Code improvements --- src/hierarchies/CMakeLists.txt | 5 -- src/hierarchies/base_hierarchy.h | 69 +++++++++---------- src/hierarchies/likelihoods/CMakeLists.txt | 1 + src/hierarchies/likelihoods/base_likelihood.h | 54 +-------------- .../likelihoods/likelihood_internal.h | 53 ++++++++++++++ src/hierarchies/lin_reg_uni_hierarchy.h | 5 +- src/hierarchies/nnig_hierarchy.h | 4 +- src/hierarchies/nnw_hierarchy.h | 7 +- src/hierarchies/priors/CMakeLists.txt | 1 + src/hierarchies/priors/abstract_prior_model.h | 7 +- src/hierarchies/priors/base_prior_model.h | 37 +++------- src/hierarchies/priors/fa_prior_model.cc | 4 +- src/hierarchies/priors/fa_prior_model.h | 5 +- src/hierarchies/priors/mnig_prior_model.cc | 6 +- src/hierarchies/priors/mnig_prior_model.h | 5 +- src/hierarchies/priors/nig_prior_model.cc | 5 +- src/hierarchies/priors/nig_prior_model.h | 5 +- src/hierarchies/priors/nw_prior_model.cc | 5 +- src/hierarchies/priors/nw_prior_model.h | 2 +- src/hierarchies/priors/nxig_prior_model.cc | 5 +- src/hierarchies/priors/nxig_prior_model.h | 5 +- src/hierarchies/priors/prior_model_internal.h | 24 +++++++ src/hierarchies/updaters/abstract_updater.h | 34 ++++++--- src/hierarchies/updaters/conjugate_updater.h | 57 --------------- src/hierarchies/updaters/fa_updater.cc | 3 +- src/hierarchies/updaters/mnig_updater.cc | 6 +- src/hierarchies/updaters/mnig_updater.h | 4 +- src/hierarchies/updaters/nnig_updater.cc | 7 +- src/hierarchies/updaters/nnig_updater.h | 4 +- src/hierarchies/updaters/nnw_updater.cc | 8 +-- src/hierarchies/updaters/nnw_updater.h | 4 +- src/hierarchies/updaters/nnxig_updater.cc | 6 +- src/hierarchies/updaters/nnxig_updater.h | 4 +- .../updaters/semi_conjugate_updater.h | 9 ++- test/hierarchies.cc | 1 + test/prior_models.cc | 16 ++--- 36 files changed, 225 insertions(+), 252 deletions(-) create mode 100644 src/hierarchies/likelihoods/likelihood_internal.h create mode 100644 src/hierarchies/priors/prior_model_internal.h delete mode 100644 src/hierarchies/updaters/conjugate_updater.h diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index a66024151..42a45fd87 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -5,14 +5,9 @@ target_sources(bayesmix nnig_hierarchy.h nnxig_hierarchy.h nnw_hierarchy.h - # nnw_hierarchy.cc - # conjugate_hierarchy.h lin_reg_uni_hierarchy.h - # lin_reg_uni_hierarchy.cc fa_hierarchy.h - # fa_hierarchy.cc lapnig_hierarchy.h - # lapnig_hierarchy.cc ) add_subdirectory(likelihoods) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 83505135c..82509b44d 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -42,7 +42,9 @@ class BaseHierarchy : public AbstractHierarchy { std::shared_ptr updater; public: + // Useful type aliases using HyperParams = decltype(prior->get_hypers()); + using ProtoHypersPtr = AbstractUpdater::ProtoHypersPtr; using ProtoHypers = AbstractUpdater::ProtoHypers; //! Constructor that allows the specification of Likelihood, PriorModel and @@ -110,26 +112,6 @@ class BaseHierarchy : public AbstractHierarchy { return out; } - //! Returns an independent, data-less copy of this object - // std::shared_ptr deep_clone() const override { - // auto out = std::make_shared(static_cast(*this)); - - // out->clear_data(); - // out->clear_summary_statistics(); - - // out->create_empty_prior(); - // std::shared_ptr new_prior(prior->New()); - // new_prior->CopyFrom(*prior.get()); - // out->get_mutable_prior()->CopyFrom(*new_prior.get()); - - // out->create_empty_hypers(); - // auto curr_hypers_proto = get_hypers_proto(); - // out->set_hypers_from_proto(*curr_hypers_proto.get()); - // out->initialize(); - // return out; - // } - //! Public wrapper for `like_lpdf()` methods double get_like_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = @@ -147,19 +129,17 @@ class BaseHierarchy : public AbstractHierarchy { return like->lpdf_grid(data, covariates); }; - // ADD EXCEPTION HANDLING //! Public wrapper for `marg_lpdf()` methods double get_marg_lpdf( - const ProtoHypers &hier_params, const Eigen::RowVectorXd &datum, + ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate /*= Eigen::RowVectorXd(0)*/) const { - if (this->is_dependent()) { + if (this->is_dependent() and covariate.size() != 0) { return marg_lpdf(hier_params, datum, covariate); } else { return marg_lpdf(hier_params, datum); } } - // ADD EXCEPTION HANDLING //! Evaluates the log-prior predictive distribution of data in a single point //! @param datum Point which is to be evaluated //! @param covariate (Optional) covariate vector associated to datum @@ -167,10 +147,9 @@ class BaseHierarchy : public AbstractHierarchy { double prior_pred_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const override { - return get_marg_lpdf(*(prior->get_hypers_proto()), datum, covariate); + return get_marg_lpdf(prior->get_hypers_proto(), datum, covariate); } - // ADD EXCEPTION HANDLING //! Evaluates the log-prior predictive distr. of data in a grid of points //! @param data Grid of points (by row) which are to be evaluated //! @param covariates (Optional) covariate vectors associated to data @@ -202,7 +181,6 @@ class BaseHierarchy : public AbstractHierarchy { return lpdf; } - // ADD EXCEPTION HANDLING //! Evaluates the log-conditional predictive distr. of data in a single point //! @param datum Point which is to be evaluated //! @param covariate (Optional) covariate vector associated to datum @@ -214,7 +192,6 @@ class BaseHierarchy : public AbstractHierarchy { datum, covariate); } - // ADD EXCEPTION HANDLING //! Evaluates the log-prior predictive distr. of data in a grid of points //! @param data Grid of points (by row) which are to be evaluated //! @param covariates (Optional) covariate vectors associated to data @@ -248,8 +225,8 @@ class BaseHierarchy : public AbstractHierarchy { //! Generates new state values from the centering prior distribution void sample_prior() override { - auto hypers = prior->get_hypers_proto(); - like->set_state_from_proto(*prior->sample(*hypers), false); + // auto hypers = prior->get_hypers_proto(); + like->set_state_from_proto(*prior->sample(/*hypers*/), false); }; //! Generates new state values from the centering posterior distribution @@ -369,7 +346,7 @@ class BaseHierarchy : public AbstractHierarchy { void initialize() override { prior->initialize(); if (is_conjugate()) { - updater->save_posterior_hypers(*prior->get_hypers_proto()); + updater->save_posterior_hypers(prior->get_hypers_proto()); } initialize_state(); like->clear_data(); @@ -394,12 +371,11 @@ class BaseHierarchy : public AbstractHierarchy { //! Initializes state parameters to appropriate values virtual void initialize_state() = 0; - // ADD EXEPTION HANDLING FOR is_dependent()? //! Evaluates the log-marginal distribution of data in a single point //! @param params Container of (prior or posterior) hyperparameter values //! @param datum Point which is to be evaluated //! @return The evaluation of the lpdf - virtual double marg_lpdf(const ProtoHypers &hier_params, + virtual double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const { if (!is_conjugate()) { throw std::runtime_error( @@ -410,13 +386,12 @@ class BaseHierarchy : public AbstractHierarchy { } } - // ADD EXEPTION HANDLING FOR is_dependent()? //! Evaluates the log-marginal distribution of data in a single point //! @param params Container of (prior or posterior) hyperparameter values //! @param datum Point which is to be evaluated //! @param covariate Covariate vector associated to datum //! @return The evaluation of the lpdf - virtual double marg_lpdf(const ProtoHypers &hier_params, + virtual double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const { if (!is_conjugate()) { @@ -427,11 +402,29 @@ class BaseHierarchy : public AbstractHierarchy { "marg_lpdf() not implemented for this hierarchy"); } } - - // TEMPORANEO! - // const Eigen::MatrixXd *dataset_ptr; }; +// OLD STUFF +//! Returns an independent, data-less copy of this object +// std::shared_ptr deep_clone() const override { +// auto out = std::make_shared(static_cast(*this)); + +// out->clear_data(); +// out->clear_summary_statistics(); + +// out->create_empty_prior(); +// std::shared_ptr new_prior(prior->New()); +// new_prior->CopyFrom(*prior.get()); +// out->get_mutable_prior()->CopyFrom(*new_prior.get()); + +// out->create_empty_hypers(); +// auto curr_hypers_proto = get_hypers_proto(); +// out->set_hypers_from_proto(*curr_hypers_proto.get()); +// out->initialize(); +// return out; +// } + // TODO: Move definitions outside the class to improve code cleaness #endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index 197d4106b..df10e8674 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -1,4 +1,5 @@ target_sources(bayesmix PUBLIC + likelihood_internal.h abstract_likelihood.h base_likelihood.h uni_norm_likelihood.h diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 369334400..aa0f4f77f 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -10,55 +10,7 @@ #include "abstract_likelihood.h" #include "algorithm_state.pb.h" - -namespace internal { - -/* SFINAE for cluster_lpdf_from_unconstrained() */ -template -auto cluster_lpdf_from_unconstrained( - const Like &like, Eigen::Matrix unconstrained_params, - int) - -> decltype(like.template cluster_lpdf_from_unconstrained( - unconstrained_params)) { - return like.template cluster_lpdf_from_unconstrained( - unconstrained_params); -} -template -auto cluster_lpdf_from_unconstrained( - const Like &like, Eigen::Matrix unconstrained_params, - double) -> T { - throw(std::runtime_error( - "cluster_lpdf_from_unconstrained() not yet implemented")); -} - -/* SFINAE for get_unconstrained_state() */ -template -auto get_unconstrained_state(const State &state, int) - -> decltype(state.get_unconstrained()) { - return state.get_unconstrained(); -} -template -auto get_unconstrained_state(const State &state, double) -> Eigen::VectorXd { - throw(std::runtime_error("get_unconstrained_state() not yet implemented")); -} - -/* SFINAE for set_state_from_unconstrained() */ -template -auto set_state_from_unconstrained(State &state, - const Eigen::VectorXd &unconstrained_state, - int) - -> decltype(state.set_from_unconstrained(unconstrained_state)) { - state.set_from_unconstrained(unconstrained_state); -} -template -auto set_state_from_unconstrained(State &state, - const Eigen::VectorXd &unconstrained_state, - double) -> void { - throw(std::runtime_error( - "set_state_from_unconstrained() not yet implemented")); -} - -} // namespace internal +#include "likelihood_internal.h" template class BaseLikelihood : public AbstractLikelihood { @@ -82,7 +34,7 @@ class BaseLikelihood : public AbstractLikelihood { //! By unconstrained parameters we mean that each entry of //! the parameter vector can range over (-inf, inf). //! Usually, some kind of transformation is required from the unconstrained - //! parameterization to the actual parameterization. + //! parametrization to the actual one. //! @param unconstrained_params vector collecting the unconstrained //! parameters //! @return The evaluation of the log likelihood over all data in the cluster @@ -97,7 +49,7 @@ class BaseLikelihood : public AbstractLikelihood { //! cluster given unconstrained parameter values. By unconstrained parameters //! we mean that each entry of the parameter vector can range over (-inf, //! inf). Usually, some kind of transformation is required from the - //! unconstrained parameterization to the actual parameterization. + //! unconstrained parametrization to the actual one. //! @param unconstrained_params vector collecting the unconstrained //! parameters //! @return The evaluation of the log likelihood over all data in the cluster diff --git a/src/hierarchies/likelihoods/likelihood_internal.h b/src/hierarchies/likelihoods/likelihood_internal.h new file mode 100644 index 000000000..28f0be6ed --- /dev/null +++ b/src/hierarchies/likelihoods/likelihood_internal.h @@ -0,0 +1,53 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_LIKELIHOOD_INTERNAL_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_LIKELIHOOD_INTERNAL_H_ + +namespace internal { + +/* SFINAE for cluster_lpdf_from_unconstrained() */ +template +auto cluster_lpdf_from_unconstrained( + const Like &like, Eigen::Matrix unconstrained_params, + int) + -> decltype(like.template cluster_lpdf_from_unconstrained( + unconstrained_params)) { + return like.template cluster_lpdf_from_unconstrained( + unconstrained_params); +} +template +auto cluster_lpdf_from_unconstrained( + const Like &like, Eigen::Matrix unconstrained_params, + double) -> T { + throw(std::runtime_error( + "cluster_lpdf_from_unconstrained() not yet implemented")); +} + +/* SFINAE for get_unconstrained_state() */ +template +auto get_unconstrained_state(const State &state, int) + -> decltype(state.get_unconstrained()) { + return state.get_unconstrained(); +} +template +auto get_unconstrained_state(const State &state, double) -> Eigen::VectorXd { + throw(std::runtime_error("get_unconstrained_state() not yet implemented")); +} + +/* SFINAE for set_state_from_unconstrained() */ +template +auto set_state_from_unconstrained(State &state, + const Eigen::VectorXd &unconstrained_state, + int) + -> decltype(state.set_from_unconstrained(unconstrained_state)) { + state.set_from_unconstrained(unconstrained_state); +} +template +auto set_state_from_unconstrained(State &state, + const Eigen::VectorXd &unconstrained_state, + double) -> void { + throw(std::runtime_error( + "set_state_from_unconstrained() not yet implemented")); +} + +} // namespace internal + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_LIKELIHOOD_INTERNAL_H_ diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index b48efaa71..1d0fa1d01 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -44,10 +44,9 @@ class LinRegUniHierarchy like->set_state(state); }; - double marg_lpdf(const ProtoHypers &hier_params, - const Eigen::RowVectorXd &datum, + double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const override { - auto params = hier_params.lin_reg_uni_state(); + auto params = hier_params->lin_reg_uni_state(); Eigen::VectorXd mean = bayesmix::to_eigen(params.mean()); Eigen::MatrixXd var_scaling = bayesmix::to_eigen(params.var_scaling()); diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index 294f1ee12..2aa273855 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -41,9 +41,9 @@ class NNIGHierarchy like->set_state(state); }; - double marg_lpdf(const ProtoHypers &hier_params, + double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { - auto params = hier_params.nnig_state(); + auto params = hier_params->nnig_state(); double sig_n = sqrt(params.scale() * (params.var_scaling() + 1) / (params.shape() * params.var_scaling())); return stan::math::student_t_lpdf(datum(0), 2 * params.shape(), diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index 9c8c7c565..4073ef053 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -44,7 +44,7 @@ class NNWHierarchy like->set_state(state); }; - double marg_lpdf(const ProtoHypers &hier_params, + double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { HyperParams pred_params = get_predictive_t_parameters(hier_params); Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); @@ -54,9 +54,8 @@ class NNWHierarchy logdet); } - HyperParams get_predictive_t_parameters( - const ProtoHypers &hier_params) const { - auto params = hier_params.nnw_state(); + HyperParams get_predictive_t_parameters(ProtoHypersPtr hier_params) const { + auto params = hier_params->nnw_state(); // Compute dof and scale of marginal distribution unsigned int dim = like->get_dim(); double nu_n = params.deg_free() - dim + 1; diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt index ddbf6e4da..d6901ee65 100644 --- a/src/hierarchies/priors/CMakeLists.txt +++ b/src/hierarchies/priors/CMakeLists.txt @@ -1,4 +1,5 @@ target_sources(bayesmix PUBLIC + prior_model_internal.h abstract_prior_model.h base_prior_model.h hyperparams.h diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index d1f839ad3..71a75d402 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -13,6 +13,11 @@ class AbstractPriorModel { public: + // Useful type aliases + using ProtoHypersPtr = + std::shared_ptr; + using ProtoHypers = ProtoHypersPtr::element_type; + //! Default destructor virtual ~AbstractPriorModel() = default; @@ -67,7 +72,7 @@ class AbstractPriorModel { // bool use_post_hypers) = 0; virtual std::shared_ptr sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) = 0; + ProtoHypersPtr hier_hypers = nullptr) = 0; //! Updates hyperparameter values given a vector of cluster states virtual void update_hypers( diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 47403f2f5..04e14dc50 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -12,28 +12,9 @@ #include "abstract_prior_model.h" #include "algorithm_state.pb.h" #include "hierarchy_id.pb.h" +#include "prior_model_internal.h" #include "src/utils/rng.h" -namespace internal { - -template -auto lpdf_from_unconstrained( - const Prior &prior, - Eigen::Matrix unconstrained_params, int) - -> decltype(prior.template lpdf_from_unconstrained( - unconstrained_params)) { - return prior.template lpdf_from_unconstrained(unconstrained_params); -} - -template -auto lpdf_from_unconstrained( - const Prior &prior, - Eigen::Matrix unconstrained_params, double) -> T { - throw(std::runtime_error("lpdf_from_unconstrained() not yet implemented")); -} - -} // namespace internal - template class BasePriorModel : public AbstractPriorModel { public: @@ -55,7 +36,7 @@ class BasePriorModel : public AbstractPriorModel { Eigen::VectorXd unconstrained_params) const override { return internal::lpdf_from_unconstrained( static_cast(*this), unconstrained_params, 0); - } + }; //! This version using `stan::math::var` type is required for Stan automatic //! aifferentiation. Evaluates the log likelihood for unconstrained parameter @@ -71,7 +52,7 @@ class BasePriorModel : public AbstractPriorModel { const override { return internal::lpdf_from_unconstrained( static_cast(*this), unconstrained_params, 0); - } + }; //! Returns an independent, data-less copy of this object std::shared_ptr clone() const override; @@ -83,7 +64,7 @@ class BasePriorModel : public AbstractPriorModel { google::protobuf::Message *get_mutable_prior() override; //! Returns the struct of the current prior hyperparameters - HyperParams get_hypers() const { return *hypers; } + HyperParams get_hypers() const { return *hypers; }; //! Returns the struct of the current posterior hyperparameters // HyperParams get_posterior_hypers() const { return post_hypers; } @@ -105,32 +86,32 @@ class BasePriorModel : public AbstractPriorModel { void check_prior_is_set() const; //! Re-initializes the prior of the hierarchy to a newly created object - void create_empty_prior() { prior.reset(new Prior); } + void create_empty_prior() { prior.reset(new Prior); }; //! Re-initializes the hyperparameters of the hierarchy to a newly created //! object - void create_empty_hypers() { hypers.reset(new HyperParams); } + void create_empty_hypers() { hypers.reset(new HyperParams); }; //! Down-casts the given generic proto message to a HierarchyHypers proto bayesmix::AlgorithmState::HierarchyHypers *downcast_hypers( google::protobuf::Message *state_) const { return google::protobuf::internal::down_cast< bayesmix::AlgorithmState::HierarchyHypers *>(state_); - } + }; //! Down-casts the given generic proto message to a HierarchyHypers proto const bayesmix::AlgorithmState::HierarchyHypers &downcast_hypers( const google::protobuf::Message &state_) const { return google::protobuf::internal::down_cast< const bayesmix::AlgorithmState::HierarchyHypers &>(state_); - } + }; //! Down-casts the given generic proto message to a ClusterState proto const bayesmix::AlgorithmState::ClusterState &downcast_state( const google::protobuf::Message &state_) const { return google::protobuf::internal::down_cast< const bayesmix::AlgorithmState::ClusterState &>(state_); - } + }; //! Container for prior hyperparameters values std::shared_ptr hypers = std::make_shared(); diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc index 4b644d860..c8ca7a95c 100644 --- a/src/hierarchies/priors/fa_prior_model.cc +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -30,12 +30,12 @@ double FAPriorModel::lpdf(const google::protobuf::Message &state_) { } std::shared_ptr FAPriorModel::sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) { + ProtoHypersPtr hier_hypers) { // Random seed auto &rng = bayesmix::Rng::Instance().get(); // Get params to use - auto params = hier_hypers.fa_state(); + auto params = get_hypers_proto()->fa_state(); Eigen::VectorXd mutilde = bayesmix::to_eigen(params.mutilde()); Eigen::VectorXd beta = bayesmix::to_eigen(params.beta()); diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index c8e7ab0fd..3544e9e30 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -16,13 +16,16 @@ class FAPriorModel : public BasePriorModel { public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + FAPriorModel() = default; ~FAPriorModel() = default; double lpdf(const google::protobuf::Message &state_) override; std::shared_ptr sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) override; + ProtoHypersPtr hier_hypers = nullptr) override; // std::shared_ptr sample( // bool use_post_hypers) override; diff --git a/src/hierarchies/priors/mnig_prior_model.cc b/src/hierarchies/priors/mnig_prior_model.cc index e6d036674..af629df99 100644 --- a/src/hierarchies/priors/mnig_prior_model.cc +++ b/src/hierarchies/priors/mnig_prior_model.cc @@ -12,10 +12,10 @@ double MNIGPriorModel::lpdf(const google::protobuf::Message &state_) { } std::shared_ptr MNIGPriorModel::sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) { + ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - - auto params = hier_hypers.lin_reg_uni_state(); + auto params = (hier_hypers) ? hier_hypers->lin_reg_uni_state() + : get_hypers_proto()->lin_reg_uni_state(); Eigen::VectorXd mean = bayesmix::to_eigen(params.mean()); Eigen::MatrixXd var_scaling = bayesmix::to_eigen(params.var_scaling()); diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h index b307675a0..4f2cf478d 100644 --- a/src/hierarchies/priors/mnig_prior_model.h +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -16,13 +16,16 @@ class MNIGPriorModel : public BasePriorModel { public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + MNIGPriorModel() = default; ~MNIGPriorModel() = default; double lpdf(const google::protobuf::Message &state_) override; std::shared_ptr sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) override; + ProtoHypersPtr hier_hypers = nullptr) override; // std::shared_ptr sample( // bool use_post_hypers) override; diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index 5cb797c08..7690ab607 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -100,9 +100,10 @@ double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { // }; std::shared_ptr NIGPriorModel::sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) { + ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - auto params = hier_hypers.nnig_state(); + auto params = (hier_hypers) ? hier_hypers->nnig_state() + : get_hypers_proto()->nnig_state(); // Hyperparams::NIG params = use_post_hypers ? post_hypers : *hypers; double var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index e791eaa11..cb0362c81 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -17,6 +17,9 @@ class NIGPriorModel : public BasePriorModel { public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + NIGPriorModel() = default; ~NIGPriorModel() = default; @@ -41,7 +44,7 @@ class NIGPriorModel : public BasePriorModel sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) override; + ProtoHypersPtr hier_hypers = nullptr) override; void update_hypers(const std::vector &states) override; diff --git a/src/hierarchies/priors/nw_prior_model.cc b/src/hierarchies/priors/nw_prior_model.cc index ebdf73910..256bb8e1a 100644 --- a/src/hierarchies/priors/nw_prior_model.cc +++ b/src/hierarchies/priors/nw_prior_model.cc @@ -122,9 +122,10 @@ double NWPriorModel::lpdf(const google::protobuf::Message &state_) { } std::shared_ptr NWPriorModel::sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) { + ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - auto params = hier_hypers.nnw_state(); + auto params = (hier_hypers) ? hier_hypers->nnw_state() + : get_hypers_proto()->nnw_state(); Eigen::MatrixXd scale = bayesmix::to_eigen(params.scale()); Eigen::VectorXd mean = bayesmix::to_eigen(params.mean()); diff --git a/src/hierarchies/priors/nw_prior_model.h b/src/hierarchies/priors/nw_prior_model.h index cb393ee7f..cb5b695ba 100644 --- a/src/hierarchies/priors/nw_prior_model.h +++ b/src/hierarchies/priors/nw_prior_model.h @@ -24,7 +24,7 @@ class NWPriorModel : public BasePriorModel sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) override; + ProtoHypersPtr hier_hypers = nullptr) override; // std::shared_ptr sample( // bool use_post_hypers) override; diff --git a/src/hierarchies/priors/nxig_prior_model.cc b/src/hierarchies/priors/nxig_prior_model.cc index 5c7a953aa..f2704ca1e 100644 --- a/src/hierarchies/priors/nxig_prior_model.cc +++ b/src/hierarchies/priors/nxig_prior_model.cc @@ -31,9 +31,10 @@ double NxIGPriorModel::lpdf(const google::protobuf::Message &state_) { } std::shared_ptr NxIGPriorModel::sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) { + ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); - auto params = hier_hypers.nnxig_state(); + auto params = (hier_hypers) ? hier_hypers->nnxig_state() + : get_hypers_proto()->nnxig_state(); double var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); double mean = stan::math::normal_rng(params.mean(), sqrt(params.var()), rng); diff --git a/src/hierarchies/priors/nxig_prior_model.h b/src/hierarchies/priors/nxig_prior_model.h index bb5f33052..bfc264736 100644 --- a/src/hierarchies/priors/nxig_prior_model.h +++ b/src/hierarchies/priors/nxig_prior_model.h @@ -18,6 +18,9 @@ class NxIGPriorModel : public BasePriorModel { public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + NxIGPriorModel() = default; ~NxIGPriorModel() = default; @@ -27,7 +30,7 @@ class NxIGPriorModel : public BasePriorModel sample( - bayesmix::AlgorithmState::HierarchyHypers hier_hypers) override; + ProtoHypersPtr hier_hypers = nullptr) override; void update_hypers(const std::vector &states) override; diff --git a/src/hierarchies/priors/prior_model_internal.h b/src/hierarchies/priors/prior_model_internal.h new file mode 100644 index 000000000..5272038ea --- /dev/null +++ b/src/hierarchies/priors/prior_model_internal.h @@ -0,0 +1,24 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_PRIOR_MODEL_INTERNAL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_PRIOR_MODEL_INTERNAL_H_ + +namespace internal { + +template +auto lpdf_from_unconstrained( + const Prior &prior, + Eigen::Matrix unconstrained_params, int) + -> decltype(prior.template lpdf_from_unconstrained( + unconstrained_params)) { + return prior.template lpdf_from_unconstrained(unconstrained_params); +} + +template +auto lpdf_from_unconstrained( + const Prior &prior, + Eigen::Matrix unconstrained_params, double) -> T { + throw(std::runtime_error("lpdf_from_unconstrained() not yet implemented")); +} + +} // namespace internal + +#endif // BAYESMIX_HIERARCHIES_PRIORS_PRIOR_MODEL_INTERNAL_H_ diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 0983b0146..2a4c32b91 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -7,13 +7,15 @@ class AbstractUpdater { public: - // Type alias - using ProtoHypers = bayesmix::AlgorithmState::HierarchyHypers; + // Type aliases + using ProtoHypersPtr = + std::shared_ptr; + using ProtoHypers = ProtoHypersPtr::element_type; //! Default destructor virtual ~AbstractUpdater() = default; - //! Returns whether the current updater is for conjugate model or not + //! Returns whether the current updater is for a (semi)conjugate model or not virtual bool is_conjugate() const { return false; }; //! Sampling from the full conditional, given the likelihood and the prior @@ -26,16 +28,28 @@ class AbstractUpdater { //! Computes the posterior hyperparameters required for the sampling in case //! of conjugate hierarchies - virtual ProtoHypers compute_posterior_hypers(AbstractLikelihood &like, - AbstractPriorModel &prior) { - throw(std::runtime_error( - "compute_posterior_hypers() not implemented for this updater")); + virtual ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) { + if (!is_conjugate()) { + throw( + std::runtime_error("Cannot call compute_posterior_hypers() from a " + "non-(semi)conjugate updater")); + } else { + throw(std::runtime_error( + "compute_posterior_hypers() not implemented for this updater")); + } } //! Stores the posterior hyperparameters in an appropriate container - virtual void save_posterior_hypers(const ProtoHypers &post_hypers_) { - throw(std::runtime_error( - "save_posterior_hypers() not implemented for this updater")); + virtual void save_posterior_hypers(ProtoHypersPtr post_hypers_) { + if (!is_conjugate()) { + throw( + std::runtime_error("Cannot call save_posterior_hypers() from a " + "non-(semi)conjugate updater")); + } else { + throw(std::runtime_error( + "save_posterior_hypers() not implemented for this updater")); + } } }; diff --git a/src/hierarchies/updaters/conjugate_updater.h b/src/hierarchies/updaters/conjugate_updater.h deleted file mode 100644 index eb92e9ce1..000000000 --- a/src/hierarchies/updaters/conjugate_updater.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_UPDATERS_CONJUGATE_UPDATER_H_ -#define BAYESMIX_HIERARCHIES_UPDATERS_CONJUGATE_UPDATER_H_ - -#include "abstract_updater.h" -#include "src/hierarchies/likelihoods/abstract_likelihood.h" -#include "src/hierarchies/priors/abstract_prior_model.h" - -template -class ConjugateUpdater : public AbstractUpdater { - public: - ConjugateUpdater() = default; - ~ConjugateUpdater() = default; - - bool is_conjugate() const override { return true; }; - void draw(AbstractLikelihood& like, AbstractPriorModel& prior, - bool update_params) override; - - protected: - Likelihood& downcast_likelihood(AbstractLikelihood& like_); - PriorModel& downcast_prior(AbstractPriorModel& prior_); -}; - -// Methods' definitions -template -Likelihood& ConjugateUpdater::downcast_likelihood( - AbstractLikelihood& like_) { - return static_cast(like_); -} - -template -PriorModel& ConjugateUpdater::downcast_prior( - AbstractPriorModel& prior_) { - return static_cast(prior_); -} - -template -void ConjugateUpdater::draw(AbstractLikelihood& like, - AbstractPriorModel& prior, - bool update_params) { - // Likelihood and PriorModel downcast - auto& likecast = downcast_likelihood(like); - auto& priorcast = downcast_prior(prior); - - // Sample from the full conditional of a conjugate hierarchy - bool set_card = true, use_post_hypers=true; - if (likecast.get_card() == 0) { - likecast.set_state_from_proto(*priorcast.sample(!use_post_hypers), !set_card); - } else { - auto prev_hypers = priorcast.get_posterior_hypers(); - compute_posterior_hypers(likecast, priorcast); - likecast.set_state_from_proto(*priorcast.sample(use_post_hypers), !set_card); - if (!update_params) - priorcast.set_posterior_hypers(prev_hypers); - } -} - -#endif // BAYESMIX_HIERARCHIES_UPDATERS_CONJUGATE_UPDATER_H_ diff --git a/src/hierarchies/updaters/fa_updater.cc b/src/hierarchies/updaters/fa_updater.cc index 8fb9f22b5..e88208b44 100644 --- a/src/hierarchies/updaters/fa_updater.cc +++ b/src/hierarchies/updaters/fa_updater.cc @@ -10,8 +10,7 @@ void FAUpdater::draw(AbstractLikelihood& like, AbstractPriorModel& prior, // Sample from the full conditional of the fa hierarchy bool set_card = true, use_post_hypers = true; if (likecast.get_card() == 0) { - auto prior_params = *(priorcast.get_hypers_proto()); - likecast.set_state_from_proto(*priorcast.sample(prior_params), !set_card); + likecast.set_state_from_proto(*priorcast.sample(), !set_card); } else { // Get state and hypers State::FA new_state = likecast.get_state(); diff --git a/src/hierarchies/updaters/mnig_updater.cc b/src/hierarchies/updaters/mnig_updater.cc index 23d7a8627..d33ffed90 100644 --- a/src/hierarchies/updaters/mnig_updater.cc +++ b/src/hierarchies/updaters/mnig_updater.cc @@ -1,6 +1,6 @@ #include "mnig_updater.h" -AbstractUpdater::ProtoHypers MNIGUpdater::compute_posterior_hypers( +AbstractUpdater::ProtoHypersPtr MNIGUpdater::compute_posterior_hypers( AbstractLikelihood& like, AbstractPriorModel& prior) { // Likelihood and Prior downcast auto& likecast = downcast_likelihood(like); @@ -16,7 +16,7 @@ AbstractUpdater::ProtoHypers MNIGUpdater::compute_posterior_hypers( // No update possible if (card == 0) { - return *(priorcast.get_hypers_proto()); + return priorcast.get_hypers_proto(); } // Compute posterior hyperparameters @@ -41,7 +41,7 @@ AbstractUpdater::ProtoHypers MNIGUpdater::compute_posterior_hypers( out.mutable_lin_reg_uni_state()->mutable_var_scaling()); out.mutable_lin_reg_uni_state()->set_shape(shape); out.mutable_lin_reg_uni_state()->set_scale(scale); - return out; + return std::make_shared(out); } // void MNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, diff --git a/src/hierarchies/updaters/mnig_updater.h b/src/hierarchies/updaters/mnig_updater.h index fd1c17214..fbb1fe5ab 100644 --- a/src/hierarchies/updaters/mnig_updater.h +++ b/src/hierarchies/updaters/mnig_updater.h @@ -13,8 +13,8 @@ class MNIGUpdater bool is_conjugate() const override { return true; }; - ProtoHypers compute_posterior_hypers(AbstractLikelihood &like, - AbstractPriorModel &prior) override; + ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; // void compute_posterior_hypers(AbstractLikelihood& like, // AbstractPriorModel& prior) override; diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc index f06c1ba5d..d09cd804e 100644 --- a/src/hierarchies/updaters/nnig_updater.cc +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -3,7 +3,7 @@ #include "src/hierarchies/likelihoods/states/includes.h" #include "src/hierarchies/priors/hyperparams.h" -AbstractUpdater::ProtoHypers NNIGUpdater::compute_posterior_hypers( +AbstractUpdater::ProtoHypersPtr NNIGUpdater::compute_posterior_hypers( AbstractLikelihood& like, AbstractPriorModel& prior) { // Likelihood and Prior downcast auto& likecast = downcast_likelihood(like); @@ -17,7 +17,7 @@ AbstractUpdater::ProtoHypers NNIGUpdater::compute_posterior_hypers( // No update possible if (card == 0) { - return *(priorcast.get_hypers_proto()); + return priorcast.get_hypers_proto(); } // Compute posterior hyperparameters @@ -39,8 +39,7 @@ AbstractUpdater::ProtoHypers NNIGUpdater::compute_posterior_hypers( out.mutable_nnig_state()->set_var_scaling(var_scaling); out.mutable_nnig_state()->set_shape(shape); out.mutable_nnig_state()->set_scale(scale); - // priorcast.set_posterior_hypers(post_params); - return out; + return std::make_shared(out); } // void NNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h index cc9e18bfe..7864f63ce 100644 --- a/src/hierarchies/updaters/nnig_updater.h +++ b/src/hierarchies/updaters/nnig_updater.h @@ -13,8 +13,8 @@ class NNIGUpdater bool is_conjugate() const override { return true; }; - ProtoHypers compute_posterior_hypers(AbstractLikelihood &like, - AbstractPriorModel &prior) override; + ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; // void compute_posterior_hypers(AbstractLikelihood& like, // AbstractPriorModel& prior) override; diff --git a/src/hierarchies/updaters/nnw_updater.cc b/src/hierarchies/updaters/nnw_updater.cc index bf8b77d62..455f4731a 100644 --- a/src/hierarchies/updaters/nnw_updater.cc +++ b/src/hierarchies/updaters/nnw_updater.cc @@ -5,7 +5,7 @@ #include "src/hierarchies/priors/hyperparams.h" #include "src/utils/proto_utils.h" -AbstractUpdater::ProtoHypers NNWUpdater::compute_posterior_hypers( +AbstractUpdater::ProtoHypersPtr NNWUpdater::compute_posterior_hypers( AbstractLikelihood& like, AbstractPriorModel& prior) { // Likelihood and Prior downcast auto& likecast = downcast_likelihood(like); @@ -19,7 +19,7 @@ AbstractUpdater::ProtoHypers NNWUpdater::compute_posterior_hypers( // No update possible if (card == 0) { - return *(prior.get_hypers_proto()); + return prior.get_hypers_proto(); } // Compute posterior hyperparameters @@ -41,12 +41,12 @@ AbstractUpdater::ProtoHypers NNWUpdater::compute_posterior_hypers( // scale_chol = Eigen::LLT(scale).matrixU(); // Proto conversion - bayesmix::AlgorithmState::HierarchyHypers out; + ProtoHypers out; bayesmix::to_proto(mean, out.mutable_nnw_state()->mutable_mean()); out.mutable_nnw_state()->set_var_scaling(var_scaling); out.mutable_nnw_state()->set_deg_free(deg_free); bayesmix::to_proto(scale, out.mutable_nnw_state()->mutable_scale()); - return out; + return std::make_shared(out); } // void NNWUpdater::compute_posterior_hypers(AbstractLikelihood& like, diff --git a/src/hierarchies/updaters/nnw_updater.h b/src/hierarchies/updaters/nnw_updater.h index 357022e23..a99567a2d 100644 --- a/src/hierarchies/updaters/nnw_updater.h +++ b/src/hierarchies/updaters/nnw_updater.h @@ -13,8 +13,8 @@ class NNWUpdater bool is_conjugate() const override { return true; }; - ProtoHypers compute_posterior_hypers(AbstractLikelihood &like, - AbstractPriorModel &prior) override; + ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; // void compute_posterior_hypers(AbstractLikelihood& like, // AbstractPriorModel& prior) override; diff --git a/src/hierarchies/updaters/nnxig_updater.cc b/src/hierarchies/updaters/nnxig_updater.cc index a4fa33ab7..b950f3e7c 100644 --- a/src/hierarchies/updaters/nnxig_updater.cc +++ b/src/hierarchies/updaters/nnxig_updater.cc @@ -3,7 +3,7 @@ #include "src/hierarchies/likelihoods/states/includes.h" #include "src/hierarchies/priors/hyperparams.h" -AbstractUpdater::ProtoHypers NNxIGUpdater::compute_posterior_hypers( +AbstractUpdater::ProtoHypersPtr NNxIGUpdater::compute_posterior_hypers( AbstractLikelihood& like, AbstractPriorModel& prior) { // Likelihood and Prior downcast auto& likecast = downcast_likelihood(like); @@ -18,7 +18,7 @@ AbstractUpdater::ProtoHypers NNxIGUpdater::compute_posterior_hypers( // No update possible if (card == 0) { - return *(priorcast.get_hypers_proto()); + return priorcast.get_hypers_proto(); } // Compute posterior hyperparameters @@ -37,7 +37,7 @@ AbstractUpdater::ProtoHypers NNxIGUpdater::compute_posterior_hypers( out.mutable_nnxig_state()->set_var(var); out.mutable_nnxig_state()->set_shape(shape); out.mutable_nnxig_state()->set_scale(scale); - return out; + return std::make_shared(out); } // void NNxIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, diff --git a/src/hierarchies/updaters/nnxig_updater.h b/src/hierarchies/updaters/nnxig_updater.h index 3200ee2df..a595e544b 100644 --- a/src/hierarchies/updaters/nnxig_updater.h +++ b/src/hierarchies/updaters/nnxig_updater.h @@ -11,8 +11,8 @@ class NNxIGUpdater NNxIGUpdater() = default; ~NNxIGUpdater() = default; - ProtoHypers compute_posterior_hypers(AbstractLikelihood &like, - AbstractPriorModel &prior) override; + ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, + AbstractPriorModel &prior) override; // void compute_posterior_hypers(AbstractLikelihood& like, // AbstractPriorModel& prior) override; diff --git a/src/hierarchies/updaters/semi_conjugate_updater.h b/src/hierarchies/updaters/semi_conjugate_updater.h index 2ced96cf9..18fa0a63f 100644 --- a/src/hierarchies/updaters/semi_conjugate_updater.h +++ b/src/hierarchies/updaters/semi_conjugate_updater.h @@ -17,12 +17,12 @@ class SemiConjugateUpdater : public AbstractUpdater { void draw(AbstractLikelihood& like, AbstractPriorModel& prior, bool update_params) override; - void save_posterior_hypers(const ProtoHypers& post_hypers_) override; + void save_posterior_hypers(ProtoHypersPtr post_hypers_) override; protected: Likelihood& downcast_likelihood(AbstractLikelihood& like_); PriorModel& downcast_prior(AbstractPriorModel& prior_); - ProtoHypers post_hypers; + ProtoHypersPtr post_hypers = std::make_shared(); }; // Methods' definitions @@ -47,8 +47,7 @@ void SemiConjugateUpdater::draw( // Sample from the full conditional of a semi-conjugate hierarchy bool set_card = true; /*, use_post_hypers=true;*/ if (likecast.get_card() == 0) { - auto prior_params = *priorcast.get_hypers_proto(); - likecast.set_state_from_proto(*priorcast.sample(prior_params), !set_card); + likecast.set_state_from_proto(*priorcast.sample(), !set_card); } else { auto post_params = compute_posterior_hypers(likecast, priorcast); likecast.set_state_from_proto(*priorcast.sample(post_params), !set_card); @@ -58,7 +57,7 @@ void SemiConjugateUpdater::draw( template void SemiConjugateUpdater::save_posterior_hypers( - const ProtoHypers& post_hypers_) { + ProtoHypersPtr post_hypers_) { post_hypers = post_hypers_; return; } diff --git a/test/hierarchies.cc b/test/hierarchies.cc index ed9b09e63..db3090f0c 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -9,6 +9,7 @@ #include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" #include "src/hierarchies/nnw_hierarchy.h" +#include "src/hierarchies/nnxig_hierarchy.h" #include "src/includes.h" #include "src/utils/proto_utils.h" #include "src/utils/rng.h" diff --git a/test/prior_models.cc b/test/prior_models.cc index 7a4b170cf..20bff0d6c 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -121,8 +121,8 @@ TEST(nig_prior_model, sample) { // Set hypers and get sampled state as proto prior->set_hypers_from_proto(hypers_proto); - auto state1 = prior->sample(*prior->get_hypers_proto()); - auto state2 = prior->sample(*prior->get_hypers_proto()); + auto state1 = prior->sample(); + auto state2 = prior->sample(); // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); @@ -203,8 +203,8 @@ TEST(nxig_prior_model, sample) { // Set hypers and get sampled state as proto prior->set_hypers_from_proto(hypers_proto); - auto state1 = prior->sample(*prior->get_hypers_proto()); - auto state2 = prior->sample(*prior->get_hypers_proto()); + auto state1 = prior->sample(); + auto state2 = prior->sample(); // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); @@ -378,8 +378,8 @@ TEST(nw_prior_model, sample) { // Set hypers and get sampled state as proto prior->set_hypers_from_proto(hypers_proto); - auto state1 = prior->sample(*prior->get_hypers_proto()); - auto state2 = prior->sample(*prior->get_hypers_proto()); + auto state1 = prior->sample(); + auto state2 = prior->sample(); // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); @@ -470,8 +470,8 @@ TEST(mnig_prior_model, sample) { // Set hypers and get sampled state as proto prior->set_hypers_from_proto(hypers_proto); - auto state1 = prior->sample(*prior->get_hypers_proto()); - auto state2 = prior->sample(*prior->get_hypers_proto()); + auto state1 = prior->sample(); + auto state2 = prior->sample(); // Check if they coincides ASSERT_TRUE(state1->DebugString() != state2->DebugString()); From dc062cd106f9d15c46d80e5e5edb40fd41f718e6 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 14:49:50 +0200 Subject: [PATCH 228/317] clean docs --- src/hierarchies/abstract_hierarchy.h | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 7e141ae68..7b998010d 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -51,25 +51,20 @@ class AbstractHierarchy { public: - // Set the likelihood for the current hierarchy. Implemented in the - // BaseHierarchy class + // Set the likelihood for the current hierarchy. // virtual void set_likelihood(std::shared_ptr like_) = // 0; - // Set the prior model for the current hierarchy. Implemented in the - // BaseHierarchy class + // Set the prior model for the current hierarchy. // virtual void set_prior(std::shared_ptr prior_) = 0; - //! Set the update algorithm for the current hierarchy. Implemented in the - //! BaseHierarchy class + //! Set the update algorithm for the current hierarchy virtual void set_updater(std::shared_ptr updater_) = 0; - //! Returns (a pointer to) the likelihood for the current hierarchy. - //! Implemented in the BaseHierarchy class + //! Returns (a pointer to) the likelihood for the current hierarchy virtual std::shared_ptr get_likelihood() = 0; - //! Returns (a pointer to) the prior model for the current hierarchy. - //! Implemented in the BaseHierarchy class + //! Returns (a pointer to) the prior model for the current hierarchy virtual std::shared_ptr get_prior() = 0; //! Default destructor @@ -260,8 +255,7 @@ class AbstractHierarchy { //! Returns whether the hierarchy represents a conjugate model or not virtual bool is_conjugate() const = 0; - //! Sets the (pointer to) the dataset in the cluster. Implemented in - //! BaseHierarchy + //! Sets the (pointer to) the dataset in the cluster virtual void set_dataset(const Eigen::MatrixXd *const dataset) = 0; protected: From 02d4f92951a9d79f7a968fdb2178b26a1e638e7d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 14:51:50 +0200 Subject: [PATCH 229/317] Improved checks for exception handling --- src/hierarchies/abstract_hierarchy.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 7b998010d..1abdde5ff 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -81,7 +81,7 @@ class AbstractHierarchy { virtual double get_like_lpdf( const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { - if (is_dependent()) { + if (is_dependent() and covariate.size() != 0) { return like_lpdf(datum, covariate); } else { return like_lpdf(datum); From 3e849c9cbde0190346910c29d741a4ffc4ad86ef Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 15:05:05 +0200 Subject: [PATCH 230/317] changed tutorial algorithm --- resources/tutorial/algo.asciipb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resources/tutorial/algo.asciipb b/resources/tutorial/algo.asciipb index 027ac8d9f..d748858d4 100644 --- a/resources/tutorial/algo.asciipb +++ b/resources/tutorial/algo.asciipb @@ -1,6 +1,6 @@ ##### GENERIC SETTINGS FOR ALL ALGORITHMS ##### # Algorithm ID string, e.g. "Neal2" -algo_id: "Neal8" +algo_id: "Neal3" # RNG initial seed: any nonnegative integer rng_seed: 20201124 From ddebf577473314638646614c83c6a0f1050e4322 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 15:09:11 +0200 Subject: [PATCH 231/317] Removed empty file --- src/hierarchies/updaters/CMakeLists.txt | 2 +- src/hierarchies/updaters/target_lpdf_unconstrained.cc | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) delete mode 100644 src/hierarchies/updaters/target_lpdf_unconstrained.cc diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index 17e4547a9..bccadbce9 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -16,5 +16,5 @@ target_sources(bayesmix PUBLIC mala_updater.h random_walk_updater.h target_lpdf_unconstrained.h - target_lpdf_unconstrained.cc + # target_lpdf_unconstrained.cc ) diff --git a/src/hierarchies/updaters/target_lpdf_unconstrained.cc b/src/hierarchies/updaters/target_lpdf_unconstrained.cc deleted file mode 100644 index fa8fef2ba..000000000 --- a/src/hierarchies/updaters/target_lpdf_unconstrained.cc +++ /dev/null @@ -1,4 +0,0 @@ -#include "target_lpdf_unconstrained.h" - -// target_lpdf_unconstrained::target_lpdf_unconstrained( -// AbstractHierarchy *p) : parent(p) {} From 9d20bfa1c5d052c1b9afbd573810db00a42b1204 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 16:27:41 +0200 Subject: [PATCH 232/317] ignore .old folders --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 34b7e5ed9..7134946e4 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ cmake-build-debug/ # .old folders src/hierarchies/updaters/.old/ +test/.old/ From f8afa6991ffc80821fcba0a1a46ba8434d4d895d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 16:28:00 +0200 Subject: [PATCH 233/317] Cleaned code --- executables/run_mcmc.cc | 4 - src/hierarchies/abstract_hierarchy.h | 7 - src/hierarchies/base_hierarchy.h | 27 +-- src/hierarchies/fa_hierarchy.h | 15 +- .../likelihoods/abstract_likelihood.h | 25 ++- src/hierarchies/likelihoods/base_likelihood.h | 22 --- .../likelihoods/laplace_likelihood.cc | 32 ---- .../likelihoods/laplace_likelihood.h | 7 - .../likelihoods/uni_norm_likelihood.cc | 12 -- src/hierarchies/lin_reg_uni_hierarchy.h | 12 +- src/hierarchies/nnig_hierarchy.h | 12 +- src/hierarchies/nnw_hierarchy.h | 15 +- src/hierarchies/nnxig_hierarchy.h | 12 +- src/hierarchies/priors/abstract_prior_model.h | 14 +- src/hierarchies/priors/base_prior_model.h | 26 --- src/hierarchies/priors/fa_prior_model.cc | 66 +------ src/hierarchies/priors/fa_prior_model.h | 6 - src/hierarchies/priors/hyperparams.h | 2 +- src/hierarchies/priors/mnig_prior_model.cc | 21 --- src/hierarchies/priors/mnig_prior_model.h | 5 - src/hierarchies/priors/nig_prior_model.cc | 16 -- src/hierarchies/priors/nig_prior_model.h | 6 - src/hierarchies/priors/nw_prior_model.cc | 26 --- src/hierarchies/priors/nw_prior_model.h | 7 - src/hierarchies/priors/nxig_prior_model.cc | 14 -- src/hierarchies/priors/nxig_prior_model.h | 7 - src/hierarchies/updaters/CMakeLists.txt | 2 - src/hierarchies/updaters/abstract_updater.h | 1 - src/hierarchies/updaters/fa_updater.h | 8 - src/hierarchies/updaters/mnig_updater.cc | 39 ----- src/hierarchies/updaters/mnig_updater.h | 3 - src/hierarchies/updaters/nnig_updater.cc | 37 ---- src/hierarchies/updaters/nnig_updater.h | 3 - src/hierarchies/updaters/nnw_updater.cc | 39 ----- src/hierarchies/updaters/nnw_updater.h | 3 - src/hierarchies/updaters/nnxig_updater.cc | 31 ---- src/hierarchies/updaters/nnxig_updater.h | 3 - src/proto/algorithm_state.proto | 1 - src/proto/hierarchy_id.proto | 6 +- src/utils/distributions.cc | 1 - test/CMakeLists.txt | 1 - test/distributions.cc | 3 - test/priors.cc | 161 ------------------ 43 files changed, 35 insertions(+), 725 deletions(-) delete mode 100644 test/priors.cc diff --git a/executables/run_mcmc.cc b/executables/run_mcmc.cc index 9d1e4c8e0..8a47aabb8 100644 --- a/executables/run_mcmc.cc +++ b/executables/run_mcmc.cc @@ -167,10 +167,6 @@ int main(int argc, char *argv[]) { mixing->get_mutable_prior()); bayesmix::read_proto_from_file(args.get("--hier-args"), hier->get_mutable_prior()); - - // std::cout << "hier->prior: \n" - // << hier->get_mutable_prior()->DebugString() << std::endl; - hier->initialize(); // Read data matrices diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 1abdde5ff..5bd4387e0 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -51,13 +51,6 @@ class AbstractHierarchy { public: - // Set the likelihood for the current hierarchy. - // virtual void set_likelihood(std::shared_ptr like_) = - // 0; - - // Set the prior model for the current hierarchy. - // virtual void set_prior(std::shared_ptr prior_) = 0; - //! Set the update algorithm for the current hierarchy virtual void set_updater(std::shared_ptr updater_) = 0; diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 82509b44d..2bd9bcda3 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -6,7 +6,6 @@ #include #include #include -// #include #include #include "abstract_hierarchy.h" @@ -225,8 +224,7 @@ class BaseHierarchy : public AbstractHierarchy { //! Generates new state values from the centering prior distribution void sample_prior() override { - // auto hypers = prior->get_hypers_proto(); - like->set_state_from_proto(*prior->sample(/*hypers*/), false); + like->set_state_from_proto(*prior->sample(), false); }; //! Generates new state values from the centering posterior distribution @@ -404,27 +402,4 @@ class BaseHierarchy : public AbstractHierarchy { } }; -// OLD STUFF -//! Returns an independent, data-less copy of this object -// std::shared_ptr deep_clone() const override { -// auto out = std::make_shared(static_cast(*this)); - -// out->clear_data(); -// out->clear_summary_statistics(); - -// out->create_empty_prior(); -// std::shared_ptr new_prior(prior->New()); -// new_prior->CopyFrom(*prior.get()); -// out->get_mutable_prior()->CopyFrom(*new_prior.get()); - -// out->create_empty_hypers(); -// auto curr_hypers_proto = get_hypers_proto(); -// out->set_hypers_from_proto(*curr_hypers_proto.get()); -// out->initialize(); -// return out; -// } - -// TODO: Move definitions outside the class to improve code cleaness - #endif // BAYESMIX_HIERARCHIES_BASE_HIERARCHY_H_ diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index 2fb0bd7a3..063be1d45 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -1,21 +1,11 @@ #ifndef BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ -// #include - -// #include -// #include -// #include - -// #include "algorithm_state.pb.h" -// #include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "src/utils/distributions.h" -// #include "hierarchy_prior.pb.h" - #include "base_hierarchy.h" +#include "hierarchy_id.pb.h" #include "likelihoods/fa_likelihood.h" #include "priors/fa_prior_model.h" +#include "src/utils/distributions.h" #include "updaters/fa_updater.h" class FAHierarchy @@ -40,7 +30,6 @@ class FAHierarchy State::FA state; state.mu = hypers.mutilde; state.psi = hypers.beta / (hypers.alpha0 + 1.); - // state.eta = Eigen::MatrixXd::Zero(hypers.card, hypers.q); state.lambda = Eigen::MatrixXd::Zero(dim, hypers.q); state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); like->set_state(state); diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index a1739c689..8239d25c7 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -14,8 +14,7 @@ class AbstractLikelihood { //! Default destructor virtual ~AbstractLikelihood() = default; - //! Returns an independent, data-less copy of this object. Implemented in - //! BaseLikelihood + //! Returns an independent, data-less copy of this object virtual std::shared_ptr clone() const = 0; //! Public wrapper for `compute_lpdf()` methods @@ -45,13 +44,12 @@ class AbstractLikelihood { "likelihood"); } - //! Evaluates the log likelihood over all the data in the cluster - //! given unconstrained parameter values. - //! By unconstrained parameters we mean that each entry of - //! the parameter vector can range over (-inf, inf). - //! Usually, some kind of transformation is required from the unconstrained - //! parameterization to the actual parameterization. This version using - //! `stan::math::var` type is required for Stan automatic aifferentiation. + //! This version using `stan::math::var` type is required for Stan automatic + //! differentiation. Evaluates the log likelihood over all the data in the + //! cluster given unconstrained parameter values. By unconstrained parameters + //! we mean that each entry of the parameter vector can range over (-inf, + //! inf). Usually, some kind of transformation is required from the + //! unconstrained parameterization to the actual parameterization. //! @param unconstrained_params vector collecting the unconstrained //! parameters //! @return The evaluation of the log likelihood over all data in the cluster @@ -85,21 +83,18 @@ class AbstractLikelihood { virtual void set_state_from_unconstrained( const Eigen::VectorXd &unconstrained_state) = 0; - //! Writes current state to a Protobuf message by pointer. Implemented in - //! BaseLikelihood + //! Writes current state to a Protobuf message by pointer virtual void write_state_to_proto(google::protobuf::Message *out) const = 0; //! Sets the (pointer to) the dataset in the cluster virtual void set_dataset(const Eigen::MatrixXd *const dataset) = 0; - //! Adds a datum and its index to the likelihood. Implemented in - //! BaseLikelihood + //! Adds a datum and its index to the likelihood virtual void add_datum( const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) = 0; - //! Removes a datum and its index from the likelihood. Implemented in - //! BaseLikelihood + //! Removes a datum and its index from the likelihood virtual void remove_datum( const int id, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) = 0; diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index aa0f4f77f..591da527a 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -86,7 +86,6 @@ class BaseLikelihood : public AbstractLikelihood { //! Returns a vector storing the state in its unconstrained form Eigen::VectorXd get_unconstrained_state() override { return internal::get_unconstrained_state(state, 0); - // return state.get_unconstrained(); } //! Updates the state of the likelihood with the object given as input @@ -96,7 +95,6 @@ class BaseLikelihood : public AbstractLikelihood { void set_state_from_unconstrained( const Eigen::VectorXd &unconstrained_state) override { internal::set_state_from_unconstrained(state, unconstrained_state, 0); - // state.set_from_unconstrained(unconstrained_state); } //! Sets the (pointer to) the dataset in the cluster @@ -166,8 +164,6 @@ void BaseLikelihood::add_datum( const Eigen::RowVectorXd &covariate) { assert(cluster_data_idx.find(id) == cluster_data_idx.end()); set_card(++card); - // card += 1; - // log_card = std::log(card); static_cast(this)->update_summary_statistics(datum, covariate, true); cluster_data_idx.insert(id); @@ -221,22 +217,4 @@ Eigen::VectorXd BaseLikelihood::lpdf_grid( return lpdf; } -// OLD STUFF -// The unconstrained parameters are mean and log(var) - -// double cluster_lpdf_from_unconstrained( -// Eigen::VectorXd unconstrained_params) const override { -// return static_cast(*this) -// .template cluster_lpdf_from_unconstrained( -// unconstrained_params); -// } - -// stan::math::var cluster_lpdf_from_unconstrained( -// Eigen::Matrix -// unconstrained_params) const override { -// return static_cast(*this) -// .template cluster_lpdf_from_unconstrained( -// unconstrained_params); -// } - #endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_BASE_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/laplace_likelihood.cc b/src/hierarchies/likelihoods/laplace_likelihood.cc index 3c29632c8..bbd6c799c 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.cc +++ b/src/hierarchies/likelihoods/laplace_likelihood.cc @@ -5,20 +5,6 @@ double LaplaceLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { datum(0), state.mean, stan::math::sqrt(state.var / 2.0)); } -// void LaplaceLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, -// bool add) { -// if (add) { -// sum_abs_diff_curr += std::abs(state.mean - datum(0, 0)); -// cluster_data_values.push_back(datum); -// } else { -// sum_abs_diff_curr -= std::abs(state.mean - datum(0, 0)); -// auto it = std::find(cluster_data_values.begin(), -// cluster_data_values.end(), -// datum); -// cluster_data_values.erase(it); -// } -// } - void LaplaceLikelihood::set_state_from_proto( const google::protobuf::Message &state_, bool update_card) { auto &statecast = downcast_state(state_); @@ -34,21 +20,3 @@ LaplaceLikelihood::get_state_proto() const { out->mutable_uni_ls_state()->set_var(state.var); return out; } - -// void LaplaceLikelihood::clear_summary_statistics() { -// cluster_data_values.clear(); -// sum_abs_diff_curr = 0; -// sum_abs_diff_prop = 0; -// } - -// double UniNormLikelihood::cluster_lpdf_from_unconstrained( -// Eigen::VectorXd unconstrained_params) { -// assert(unconstrained_params.size() == 2); -// double mean = unconstrained_params(0); -// double var = std::exp(unconstrained_params(1)); -// double out = -(data_sum_squares - 2 * mean * data_sum + card * mean * -// mean) / -// (2 * var); -// out -= card * 0.5 * std::log(stan::math::TWO_PI * var); -// return out; -// } diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h index 84d5a200a..15e2124ed 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.h +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -47,13 +47,6 @@ class LaplaceLikelihood void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override { return; }; - - //! Set of values of data points belonging to this cluster - // std::list cluster_data_values; - //! Sum of absolute differences for current params - // double sum_abs_diff_curr = 0; - //! Sum of absolute differences for proposal params - // double sum_abs_diff_prop = 0; }; #endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_LAPLACE_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc index bda296076..b0ceada1f 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -35,15 +35,3 @@ void UniNormLikelihood::clear_summary_statistics() { data_sum = 0; data_sum_squares = 0; } - -// double UniNormLikelihood::cluster_lpdf_from_unconstrained( -// Eigen::VectorXd unconstrained_params) { -// assert(unconstrained_params.size() == 2); -// double mean = unconstrained_params(0); -// double var = std::exp(unconstrained_params(1)); -// double out = -(data_sum_squares - 2 * mean * data_sum + card * mean * -// mean) / -// (2 * var); -// out -= card * 0.5 * std::log(stan::math::TWO_PI * var); -// return out; -// } diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index 1d0fa1d01..ae449a103 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -1,18 +1,8 @@ #ifndef BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_LIN_REG_UNI_HIERARCHY_H_ -// #include - -// #include -// #include -// #include - -// #include "algorithm_state.pb.h" -// #include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -// #include "hierarchy_prior.pb.h" - #include "base_hierarchy.h" +#include "hierarchy_id.pb.h" #include "likelihoods/uni_lin_reg_likelihood.h" #include "priors/mnig_prior_model.h" #include "updaters/mnig_updater.h" diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index 2aa273855..b70e0eeac 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -1,18 +1,8 @@ #ifndef BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ -// #include - -// #include -// #include -// #include - -// #include "algorithm_state.pb.h" -// #include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -// #include "hierarchy_prior.pb.h" - #include "base_hierarchy.h" +#include "hierarchy_id.pb.h" #include "likelihoods/uni_norm_likelihood.h" #include "priors/nig_prior_model.h" #include "updaters/nnig_updater.h" diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index 4073ef053..35e5920af 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -1,21 +1,11 @@ #ifndef BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_NNW_HIERARCHY_H_ -// #include - -// #include -// #include -// #include - -// #include "algorithm_state.pb.h" -// #include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -#include "src/utils/distributions.h" -// #include "hierarchy_prior.pb.h" - #include "base_hierarchy.h" +#include "hierarchy_id.pb.h" #include "likelihoods/multi_norm_likelihood.h" #include "priors/nw_prior_model.h" +#include "src/utils/distributions.h" #include "updaters/nnw_updater.h" class NNWHierarchy @@ -60,7 +50,6 @@ class NNWHierarchy unsigned int dim = like->get_dim(); double nu_n = params.deg_free() - dim + 1; double coeff = (params.var_scaling() + 1) / (params.var_scaling() * nu_n); - // Eigen::MatrixXd scale = bayesmix::to_eigen(params.scale()); Eigen::MatrixXd scale_chol = Eigen::LLT(bayesmix::to_eigen(params.scale())) .matrixU(); diff --git a/src/hierarchies/nnxig_hierarchy.h b/src/hierarchies/nnxig_hierarchy.h index a99a27909..a0de895fe 100644 --- a/src/hierarchies/nnxig_hierarchy.h +++ b/src/hierarchies/nnxig_hierarchy.h @@ -1,18 +1,8 @@ #ifndef BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_NNXIG_HIERARCHY_H_ -// #include - -// #include -// #include -// #include - -// #include "algorithm_state.pb.h" -// #include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -// #include "hierarchy_prior.pb.h" - #include "base_hierarchy.h" +#include "hierarchy_id.pb.h" #include "likelihoods/uni_norm_likelihood.h" #include "priors/nxig_prior_model.h" #include "updaters/nnxig_updater.h" diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 71a75d402..412c51ae3 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -21,12 +21,10 @@ class AbstractPriorModel { //! Default destructor virtual ~AbstractPriorModel() = default; - //! Returns an independent, data-less copy of this object. Implemented in - //! BasePriorModel + //! Returns an independent, data-less copy of this object virtual std::shared_ptr clone() const = 0; - //! Returns an independent, data-less deep copy of this object. Implemented - //! in BasePriorModel + //! Returns an independent, data-less deep copy of this object virtual std::shared_ptr deep_clone() const = 0; //! Evaluates the log likelihood for the prior model, given the state of the @@ -65,8 +63,10 @@ class AbstractPriorModel { } //! Sampling from the prior model - //! @param use_post_hypers It is a `bool` which decides whether to use prior - //! or posterior parameters + //! @param hier_hypers A pointer to a + //! `bayesmix::AlgorithmState::hierarchyHypers` object, which defines the + //! parameters to use for the sampling. The default behaviour (i.e. + //! `hier_hypers = nullptr`) uses prior hyperparameters //! @return A Protobuf message storing the state sampled from the prior model // virtual std::shared_ptr sample( // bool use_post_hypers) = 0; @@ -86,7 +86,7 @@ class AbstractPriorModel { const google::protobuf::Message &state_) = 0; //! Writes current values of the hyperparameters to a Protobuf message by - //! pointer. Implemented in BasePriorModel + //! pointer virtual void write_hypers_to_proto(google::protobuf::Message *out) const = 0; //! Writes current value of hyperparameters to a Protobuf message and diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 04e14dc50..efdd70e30 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -66,14 +66,6 @@ class BasePriorModel : public AbstractPriorModel { //! Returns the struct of the current prior hyperparameters HyperParams get_hypers() const { return *hypers; }; - //! Returns the struct of the current posterior hyperparameters - // HyperParams get_posterior_hypers() const { return post_hypers; } - - //! Updates the current value of the posterior hyperparameters - // void set_posterior_hypers(const HyperParams &_post_hypers) { - // post_hypers = _post_hypers; - // }; - //! Writes current values of the hyperparameters to a Protobuf message by //! pointer void write_hypers_to_proto(google::protobuf::Message *out) const override; @@ -116,9 +108,6 @@ class BasePriorModel : public AbstractPriorModel { //! Container for prior hyperparameters values std::shared_ptr hypers = std::make_shared(); - //! Container for posterior hyperparameters values - // HyperParams post_hypers; - //! Pointer to a Protobuf prior object for this class std::shared_ptr prior; }; @@ -185,19 +174,4 @@ void BasePriorModel::check_prior_is_set() const { } } -// OLD STUFF -// double lpdf_from_unconstrained( -// Eigen::VectorXd unconstrained_params) const override { -// return static_cast(*this) -// .template lpdf_from_unconstrained(unconstrained_params); -// } - -// stan::math::var lpdf_from_unconstrained( -// Eigen::Matrix -// unconstrained_params) const override { -// return static_cast(*this) -// .template lpdf_from_unconstrained( -// unconstrained_params); -// } - #endif // BAYESMIX_HIERARCHIES_PRIORS_BASE_PRIOR_MODEL_H_ diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc index c8ca7a95c..99271094a 100644 --- a/src/hierarchies/priors/fa_prior_model.cc +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -3,13 +3,17 @@ double FAPriorModel::lpdf(const google::protobuf::Message &state_) { // Downcast state auto &state = downcast_state(state_).fa_state(); + // Proto2Eigen conversion Eigen::VectorXd mu = bayesmix::to_eigen(state.mu()); Eigen::VectorXd psi = bayesmix::to_eigen(state.psi()); + // Eigen::MatrixXd eta = bayesmix::to_eigen(state.eta()); Eigen::MatrixXd lambda = bayesmix::to_eigen(state.lambda()); + // Initialize lpdf value double target = 0.; + // Compute lpdf for (size_t j = 0; j < dim; j++) { target += @@ -20,11 +24,7 @@ double FAPriorModel::lpdf(const google::protobuf::Message &state_) { target += stan::math::normal_lpdf(lambda(j, i), 0, 1); } } - // for (size_t i = 0; i < eta.rows(); i++) { - // for (size_t j = 0; j < hypers->q; j++) { - // target += stan::math::normal_lpdf(eta(i, j), 0, 1); - // } - // } + // Return lpdf contribution return target; } @@ -62,55 +62,6 @@ std::shared_ptr FAPriorModel::sample( return std::make_shared(state); } -// std::shared_ptr FAPriorModel::sample( -// bool use_post_hypers) { -// // Random seed -// auto &rng = bayesmix::Rng::Instance().get(); - -// // Select params to use -// Hyperparams::FA params = use_post_hypers ? post_hypers : *hypers; - -// // Compute output state -// State::FA out; -// out.mu = params.mutilde; -// out.psi = params.beta / (params.alpha0 + 1.); -// // out.eta = Eigen::MatrixXd::Zero(params.card, params.q); -// out.lambda = Eigen::MatrixXd::Zero(dim, params.q); -// for (size_t j = 0; j < dim; j++) { -// out.mu[j] = -// stan::math::normal_rng(params.mutilde[j], sqrt(params.phi), rng); - -// out.psi[j] = stan::math::inv_gamma_rng(params.alpha0, params.beta[j], -// rng); - -// for (size_t i = 0; i < params.q; i++) { -// out.lambda(j, i) = stan::math::normal_rng(0, 1, rng); -// } -// } -// // for (size_t i = 0; i < params.card; i++) { -// // for (size_t j = 0; j < params.q; j++) { -// // out.eta(i, j) = stan::math::normal_rng(0, 1, rng); -// // } -// // } - -// // Questi conti non li passo al proto, attenzione !!! -// // out.psi_inverse = out.psi.cwiseInverse().asDiagonal(); -// // compute_wood_factors(out.cov_wood, out.cov_logdet, out.lambda, -// // out.psi_inverse); - -// // Eigen2Proto conversion -// bayesmix::AlgorithmState::ClusterState state; -// bayesmix::to_proto(out.mu, state.mutable_fa_state()->mutable_mu()); -// bayesmix::to_proto(out.psi, state.mutable_fa_state()->mutable_psi()); -// // bayesmix::to_proto(out.eta, state.mutable_fa_state()->mutable_eta()); -// bayesmix::to_proto(out.lambda, -// state.mutable_fa_state()->mutable_lambda()); return -// std::make_shared(state); - -// // MANCA PSI_INVERSE E GLI OUTPUT DA COMPUTE_WOOD_FACTORS !!! BISOGNA -// // CAMBIARE IL PROTO -// } - void FAPriorModel::update_hypers( const std::vector &states) { auto &rng = bayesmix::Rng::Instance().get(); @@ -129,7 +80,6 @@ void FAPriorModel::set_hypers_from_proto( hypers->beta = bayesmix::to_eigen(hyperscast.beta()); hypers->phi = hyperscast.phi(); hypers->q = hyperscast.q(); - // hypers->card = hyperscast.card(); } std::shared_ptr @@ -140,7 +90,6 @@ FAPriorModel::get_hypers_proto() const { hypers_.set_alpha0(hypers->alpha0); hypers_.set_phi(hypers->phi); hypers_.set_q(hypers->q); - // hypers_.set_card(hypers->card); auto out = std::make_shared(); out->mutable_fa_state()->CopyFrom(hypers_); @@ -156,7 +105,6 @@ void FAPriorModel::initialize_hypers() { hypers->phi = prior->fixed_values().phi(); hypers->alpha0 = prior->fixed_values().alpha0(); hypers->q = prior->fixed_values().q(); - // hypers->card = prior->fixed_values().card(); // Check validity if (dim != hypers->beta.rows()) { @@ -177,9 +125,6 @@ void FAPriorModel::initialize_hypers() { if (hypers->q <= 0) { throw std::invalid_argument("Number of factors must be > 0"); } - // if (hypers->card < 0) { - // throw std::invalid_argument("Number of data must be >= 0"); - // } } else { @@ -187,6 +132,7 @@ void FAPriorModel::initialize_hypers() { } } +// TODO /* // Automatic initialization if (dim == 0) { diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index 3544e9e30..66088d5f4 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -1,13 +1,10 @@ #ifndef BAYESMIX_HIERARCHIES_PRIORS_FA_PRIOR_MODEL_H_ #define BAYESMIX_HIERARCHIES_PRIORS_FA_PRIOR_MODEL_H_ -// #include - #include #include #include -// #include "algorithm_state.pb.h" #include "base_prior_model.h" #include "hierarchy_prior.pb.h" #include "hyperparams.h" @@ -27,9 +24,6 @@ class FAPriorModel std::shared_ptr sample( ProtoHypersPtr hier_hypers = nullptr) override; - // std::shared_ptr sample( - // bool use_post_hypers) override; - void update_hypers(const std::vector &states) override; diff --git a/src/hierarchies/priors/hyperparams.h b/src/hierarchies/priors/hyperparams.h index acdcd8f7a..1aca6dc4a 100644 --- a/src/hierarchies/priors/hyperparams.h +++ b/src/hierarchies/priors/hyperparams.h @@ -28,7 +28,7 @@ struct MNIG { struct FA { Eigen::VectorXd mutilde, beta; double phi, alpha0; - unsigned int /*card,*/ q; + unsigned int q; }; } // namespace Hyperparams diff --git a/src/hierarchies/priors/mnig_prior_model.cc b/src/hierarchies/priors/mnig_prior_model.cc index af629df99..215647780 100644 --- a/src/hierarchies/priors/mnig_prior_model.cc +++ b/src/hierarchies/priors/mnig_prior_model.cc @@ -32,27 +32,6 @@ std::shared_ptr MNIGPriorModel::sample( return std::make_shared(state); } -// std::shared_ptr MNIGPriorModel::sample( -// bool use_post_hypers) { -// auto &rng = bayesmix::Rng::Instance().get(); -// Hyperparams::MNIG params = use_post_hypers ? post_hypers : *hypers; - -// double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); -// Eigen::VectorXd regression_coeffs = stan::math::multi_normal_prec_rng( -// params.mean, params.var_scaling / var, rng); - -// bayesmix::AlgorithmState::ClusterState state; -// // bayesmix::Vector regression_coeffs_proto; -// bayesmix::to_proto( -// regression_coeffs, -// state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()); -// // -// state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()->CopyFrom(regression_coeffs_proto); -// state.mutable_lin_reg_uni_ls_state()->set_var(var); - -// return std::make_shared(state); -// } - void MNIGPriorModel::update_hypers( const std::vector &states) { if (prior->has_fixed_values()) { diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h index 4f2cf478d..f6949fce7 100644 --- a/src/hierarchies/priors/mnig_prior_model.h +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -1,13 +1,10 @@ #ifndef BAYESMIX_HIERARCHIES_PRIORS_MNIG_PRIOR_MODEL_H_ #define BAYESMIX_HIERARCHIES_PRIORS_MNIG_PRIOR_MODEL_H_ -// #include - #include #include #include -// #include "algorithm_state.pb.h" #include "base_prior_model.h" #include "hierarchy_prior.pb.h" #include "hyperparams.h" @@ -26,8 +23,6 @@ class MNIGPriorModel : public BasePriorModel sample( ProtoHypersPtr hier_hypers = nullptr) override; - // std::shared_ptr sample( - // bool use_post_hypers) override; void update_hypers(const std::vector &states) override; diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index 7690ab607..248372f49 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -84,28 +84,12 @@ double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { return target; } -// std::shared_ptr NIGPriorModel::sample( -// bool use_post_hypers) { -// auto &rng = bayesmix::Rng::Instance().get(); -// Hyperparams::NIG params = use_post_hypers ? post_hypers : *hypers; -// double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); -// double mean = -// stan::math::normal_rng(params.mean, sqrt(var / params.var_scaling), -// rng); - -// bayesmix::AlgorithmState::ClusterState state; -// state.mutable_uni_ls_state()->set_mean(mean); -// state.mutable_uni_ls_state()->set_var(var); -// return std::make_shared(state); -// }; - std::shared_ptr NIGPriorModel::sample( ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); auto params = (hier_hypers) ? hier_hypers->nnig_state() : get_hypers_proto()->nnig_state(); - // Hyperparams::NIG params = use_post_hypers ? post_hypers : *hypers; double var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); double mean = stan::math::normal_rng(params.mean(), sqrt(var / params.var_scaling()), rng); diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index cb0362c81..2c4357955 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -1,14 +1,11 @@ #ifndef BAYESMIX_HIERARCHIES_PRIORS_NIG_PRIOR_MODEL_H_ #define BAYESMIX_HIERARCHIES_PRIORS_NIG_PRIOR_MODEL_H_ -// #include - #include #include #include #include -// #include "algorithm_state.pb.h" #include "base_prior_model.h" #include "hierarchy_prior.pb.h" #include "hyperparams.h" @@ -40,9 +37,6 @@ class NIGPriorModel : public BasePriorModel sample( - // bool use_post_hypers) override; - std::shared_ptr sample( ProtoHypersPtr hier_hypers = nullptr) override; diff --git a/src/hierarchies/priors/nw_prior_model.cc b/src/hierarchies/priors/nw_prior_model.cc index 256bb8e1a..50fd2c0a0 100644 --- a/src/hierarchies/priors/nw_prior_model.cc +++ b/src/hierarchies/priors/nw_prior_model.cc @@ -147,32 +147,6 @@ std::shared_ptr NWPriorModel::sample( return std::make_shared(state); }; -// std::shared_ptr NWPriorModel::sample( -// bool use_post_hypers) { -// auto &rng = bayesmix::Rng::Instance().get(); - -// Hyperparams::NW params = use_post_hypers ? post_hypers : *hypers; - -// Eigen::MatrixXd tau_new = -// stan::math::wishart_rng(params.deg_free, params.scale, rng); - -// // Update state -// State::MultiLS out; -// out.mean = stan::math::multi_normal_prec_rng( -// params.mean, tau_new * params.var_scaling, rng); -// write_prec_to_state(tau_new, &out); - -// // Make output state -// bayesmix::AlgorithmState::ClusterState state; -// bayesmix::to_proto(out.mean, -// state.mutable_multi_ls_state()->mutable_mean()); -// bayesmix::to_proto(out.prec, -// state.mutable_multi_ls_state()->mutable_prec()); -// bayesmix::to_proto(out.prec_chol, -// state.mutable_multi_ls_state()->mutable_prec_chol()); -// return std::make_shared(state); -// }; - void NWPriorModel::update_hypers( const std::vector &states) { auto &rng = bayesmix::Rng::Instance().get(); diff --git a/src/hierarchies/priors/nw_prior_model.h b/src/hierarchies/priors/nw_prior_model.h index cb5b695ba..e7c0f06dd 100644 --- a/src/hierarchies/priors/nw_prior_model.h +++ b/src/hierarchies/priors/nw_prior_model.h @@ -1,15 +1,11 @@ #ifndef BAYESMIX_HIERARCHIES_PRIORS_NW_PRIOR_MODEL_H_ #define BAYESMIX_HIERARCHIES_PRIORS_NW_PRIOR_MODEL_H_ -// #include - -// #include #include #include #include #include -// #include "algorithm_state.pb.h" #include "base_prior_model.h" #include "hierarchy_prior.pb.h" #include "hyperparams.h" @@ -26,9 +22,6 @@ class NWPriorModel : public BasePriorModel sample( ProtoHypersPtr hier_hypers = nullptr) override; - // std::shared_ptr sample( - // bool use_post_hypers) override; - void update_hypers(const std::vector &states) override; diff --git a/src/hierarchies/priors/nxig_prior_model.cc b/src/hierarchies/priors/nxig_prior_model.cc index f2704ca1e..52dacc703 100644 --- a/src/hierarchies/priors/nxig_prior_model.cc +++ b/src/hierarchies/priors/nxig_prior_model.cc @@ -45,20 +45,6 @@ std::shared_ptr NxIGPriorModel::sample( return std::make_shared(state); }; -// std::shared_ptr NxIGPriorModel::sample( -// bool use_post_hypers) { -// auto &rng = bayesmix::Rng::Instance().get(); -// Hyperparams::NxIG params = use_post_hypers ? post_hypers : *hypers; - -// double var = stan::math::inv_gamma_rng(params.shape, params.scale, rng); -// double mean = stan::math::normal_rng(params.mean, sqrt(params.var), rng); - -// bayesmix::AlgorithmState::ClusterState state; -// state.mutable_uni_ls_state()->set_mean(mean); -// state.mutable_uni_ls_state()->set_var(var); -// return std::make_shared(state); -// }; - void NxIGPriorModel::update_hypers( const std::vector &states) { if (prior->has_fixed_values()) { diff --git a/src/hierarchies/priors/nxig_prior_model.h b/src/hierarchies/priors/nxig_prior_model.h index bfc264736..86a517c58 100644 --- a/src/hierarchies/priors/nxig_prior_model.h +++ b/src/hierarchies/priors/nxig_prior_model.h @@ -1,15 +1,11 @@ #ifndef BAYESMIX_HIERARCHIES_PRIORS_NXIG_PRIOR_MODEL_H_ #define BAYESMIX_HIERARCHIES_PRIORS_NXIG_PRIOR_MODEL_H_ -// #include - -// #include #include #include #include #include -// #include "algorithm_state.pb.h" #include "base_prior_model.h" #include "hierarchy_prior.pb.h" #include "hyperparams.h" @@ -26,9 +22,6 @@ class NxIGPriorModel : public BasePriorModel sample( - // bool use_post_hypers) override; - std::shared_ptr sample( ProtoHypersPtr hier_hypers = nullptr) override; diff --git a/src/hierarchies/updaters/CMakeLists.txt b/src/hierarchies/updaters/CMakeLists.txt index bccadbce9..dbc5e7e9a 100644 --- a/src/hierarchies/updaters/CMakeLists.txt +++ b/src/hierarchies/updaters/CMakeLists.txt @@ -1,6 +1,5 @@ target_sources(bayesmix PUBLIC abstract_updater.h - # conjugate_updater.h semi_conjugate_updater.h nnig_updater.h nnig_updater.cc @@ -16,5 +15,4 @@ target_sources(bayesmix PUBLIC mala_updater.h random_walk_updater.h target_lpdf_unconstrained.h - # target_lpdf_unconstrained.cc ) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index 2a4c32b91..b1c8b849c 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -3,7 +3,6 @@ #include "src/hierarchies/likelihoods/abstract_likelihood.h" #include "src/hierarchies/priors/abstract_prior_model.h" -// #include "src/hierarchies/updaters/target_lpdf_unconstrained.h" class AbstractUpdater { public: diff --git a/src/hierarchies/updaters/fa_updater.h b/src/hierarchies/updaters/fa_updater.h index b1ffcb3ae..671cbaae5 100644 --- a/src/hierarchies/updaters/fa_updater.h +++ b/src/hierarchies/updaters/fa_updater.h @@ -24,14 +24,6 @@ class FAUpdater : public AbstractUpdater { const FALikelihood& like); void sample_psi(State::FA& state, const Hyperparams::FA& hypers, const FALikelihood& like); - // void sample_eta(State::FA & state, const Hyperparams::FA & hypers, const - // Eigen::MatrixXd * dataset_ptr, const std::set & cluster_data_idx); - // void sample_mu(State::FA & state, const Hyperparams::FA & hypers, const - // Eigen::VectorXd & data_sum); void sample_lambda(State::FA & state, const - // Hyperparams::FA & hypers, const Eigen::MatrixXd * dataset_ptr, const - // std::set & cluster_data_idx, size_t dim); void sample_psi(State::FA & - // state, const Hyperparams::FA & hypers, const Eigen::MatrixXd * - // dataset_ptr, const std::set & cluster_data_idx, size_t dim); }; #endif // BAYESMIX_HIERARCHIES_UPDATERS_FA_UPDATER_H_ diff --git a/src/hierarchies/updaters/mnig_updater.cc b/src/hierarchies/updaters/mnig_updater.cc index d33ffed90..beb822c1a 100644 --- a/src/hierarchies/updaters/mnig_updater.cc +++ b/src/hierarchies/updaters/mnig_updater.cc @@ -26,7 +26,6 @@ AbstractUpdater::ProtoHypersPtr MNIGUpdater::compute_posterior_hypers( var_scaling = covar_sum_squares + hypers.var_scaling; auto llt = var_scaling.llt(); - // var_scaling_inv = llt.solve(Eigen::MatrixXd::Identity(dim, dim)); mean = llt.solve(mixed_prod + hypers.var_scaling * hypers.mean); shape = hypers.shape + 0.5 * card; scale = hypers.scale + @@ -43,41 +42,3 @@ AbstractUpdater::ProtoHypersPtr MNIGUpdater::compute_posterior_hypers( out.mutable_lin_reg_uni_state()->set_scale(scale); return std::make_shared(out); } - -// void MNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, -// AbstractPriorModel& prior) { -// // Likelihood and Prior downcast -// auto& likecast = downcast_likelihood(like); -// auto& priorcast = downcast_prior(prior); - -// // Getting required quantities from likelihood and prior -// int card = likecast.get_card(); -// unsigned int dim = likecast.get_dim(); -// double data_sum_squares = likecast.get_data_sum_squares(); -// Eigen::MatrixXd covar_sum_squares = likecast.get_covar_sum_squares(); -// Eigen::MatrixXd mixed_prod = likecast.get_mixed_prod(); -// auto hypers = priorcast.get_hypers(); - -// // No update possible -// if (card == 0) { -// priorcast.set_posterior_hypers(hypers); -// return; -// } - -// // Compute posterior hyperparameters -// Hyperparams::MNIG post_params; -// post_params.var_scaling = covar_sum_squares + hypers.var_scaling; -// auto llt = post_params.var_scaling.llt(); -// post_params.var_scaling_inv = llt.solve(Eigen::MatrixXd::Identity(dim, -// dim)); post_params.mean = llt.solve(mixed_prod + hypers.var_scaling * -// hypers.mean); post_params.shape = hypers.shape + 0.5 * card; -// post_params.scale = -// hypers.scale + -// 0.5 * (data_sum_squares + -// hypers.mean.transpose() * hypers.var_scaling * hypers.mean - -// post_params.mean.transpose() * post_params.var_scaling * -// post_params.mean); - -// priorcast.set_posterior_hypers(post_params); -// return; -// }; diff --git a/src/hierarchies/updaters/mnig_updater.h b/src/hierarchies/updaters/mnig_updater.h index fbb1fe5ab..4f7bb2d35 100644 --- a/src/hierarchies/updaters/mnig_updater.h +++ b/src/hierarchies/updaters/mnig_updater.h @@ -15,9 +15,6 @@ class MNIGUpdater ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, AbstractPriorModel &prior) override; - - // void compute_posterior_hypers(AbstractLikelihood& like, - // AbstractPriorModel& prior) override; }; #endif // BAYESMIX_HIERARCHIES_UPDATERS_MNIG_UPDATER_H_ diff --git a/src/hierarchies/updaters/nnig_updater.cc b/src/hierarchies/updaters/nnig_updater.cc index d09cd804e..c73ae39c1 100644 --- a/src/hierarchies/updaters/nnig_updater.cc +++ b/src/hierarchies/updaters/nnig_updater.cc @@ -22,7 +22,6 @@ AbstractUpdater::ProtoHypersPtr NNIGUpdater::compute_posterior_hypers( // Compute posterior hyperparameters double mean, var_scaling, shape, scale; - // Hyperparams::NIG post_params; double y_bar = data_sum / (1.0 * card); // sample mean double ss = data_sum_squares - card * y_bar * y_bar; mean = (hypers.var_scaling * hypers.mean + data_sum) / @@ -41,39 +40,3 @@ AbstractUpdater::ProtoHypersPtr NNIGUpdater::compute_posterior_hypers( out.mutable_nnig_state()->set_scale(scale); return std::make_shared(out); } - -// void NNIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, -// AbstractPriorModel& prior) { -// // Likelihood and Prior downcast -// auto& likecast = downcast_likelihood(like); -// auto& priorcast = downcast_prior(prior); - -// // Getting required quantities from likelihood and prior -// int card = likecast.get_card(); -// double data_sum = likecast.get_data_sum(); -// double data_sum_squares = likecast.get_data_sum_squares(); -// auto hypers = priorcast.get_hypers(); - -// // No update possible -// if (card == 0) { -// priorcast.set_posterior_hypers(hypers); -// return; -// } - -// // Compute posterior hyperparameters -// Hyperparams::NIG post_params; -// double y_bar = data_sum / (1.0 * card); // sample mean -// double ss = data_sum_squares - card * y_bar * y_bar; -// post_params.mean = (hypers.var_scaling * hypers.mean + data_sum) / -// (hypers.var_scaling + card); -// post_params.var_scaling = hypers.var_scaling + card; -// post_params.shape = hypers.shape + 0.5 * card; -// post_params.scale = hypers.scale + 0.5 * ss + -// 0.5 * hypers.var_scaling * card * (y_bar - -// hypers.mean) * -// (y_bar - hypers.mean) / (card + -// hypers.var_scaling); - -// priorcast.set_posterior_hypers(post_params); -// return; -// }; diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h index 7864f63ce..8a7f52b2d 100644 --- a/src/hierarchies/updaters/nnig_updater.h +++ b/src/hierarchies/updaters/nnig_updater.h @@ -15,9 +15,6 @@ class NNIGUpdater ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, AbstractPriorModel &prior) override; - - // void compute_posterior_hypers(AbstractLikelihood& like, - // AbstractPriorModel& prior) override; }; #endif // BAYESMIX_HIERARCHIES_UPDATERS_NNIG_UPDATER_H_ diff --git a/src/hierarchies/updaters/nnw_updater.cc b/src/hierarchies/updaters/nnw_updater.cc index 455f4731a..f265f84f4 100644 --- a/src/hierarchies/updaters/nnw_updater.cc +++ b/src/hierarchies/updaters/nnw_updater.cc @@ -38,7 +38,6 @@ AbstractUpdater::ProtoHypersPtr NNWUpdater::compute_posterior_hypers( (mubar - hypers.mean) * (mubar - hypers.mean).transpose(); scale_inv = tau_temp + hypers.scale_inv; scale = stan::math::inverse_spd(scale_inv); - // scale_chol = Eigen::LLT(scale).matrixU(); // Proto conversion ProtoHypers out; @@ -48,41 +47,3 @@ AbstractUpdater::ProtoHypersPtr NNWUpdater::compute_posterior_hypers( bayesmix::to_proto(scale, out.mutable_nnw_state()->mutable_scale()); return std::make_shared(out); } - -// void NNWUpdater::compute_posterior_hypers(AbstractLikelihood& like, -// AbstractPriorModel& prior) { -// // Likelihood and Prior downcast -// auto& likecast = downcast_likelihood(like); -// auto& priorcast = downcast_prior(prior); - -// // Getting required quantities from likelihood and prior -// int card = likecast.get_card(); -// Eigen::VectorXd data_sum = likecast.get_data_sum(); -// Eigen::MatrixXd data_sum_squares = likecast.get_data_sum_squares(); -// auto hypers = priorcast.get_hypers(); - -// // No update possible -// if (card == 0) { -// priorcast.set_posterior_hypers(hypers); -// return; -// } - -// // Compute posterior hyperparameters -// Hyperparams::NW post_params; -// post_params.var_scaling = hypers.var_scaling + card; -// post_params.deg_free = hypers.deg_free + card; -// Eigen::VectorXd mubar = data_sum.array() / card; // sample mean -// post_params.mean = (hypers.var_scaling * hypers.mean + card * mubar) / -// (hypers.var_scaling + card); -// // Compute tau_n -// Eigen::MatrixXd tau_temp = -// data_sum_squares - card * mubar * mubar.transpose(); -// tau_temp += (card * hypers.var_scaling / (card + hypers.var_scaling)) * -// (mubar - hypers.mean) * (mubar - hypers.mean).transpose(); -// post_params.scale_inv = tau_temp + hypers.scale_inv; -// post_params.scale = stan::math::inverse_spd(post_params.scale_inv); -// post_params.scale_chol = -// Eigen::LLT(post_params.scale).matrixU(); -// priorcast.set_posterior_hypers(post_params); -// return; -// }; diff --git a/src/hierarchies/updaters/nnw_updater.h b/src/hierarchies/updaters/nnw_updater.h index a99567a2d..bd7977acd 100644 --- a/src/hierarchies/updaters/nnw_updater.h +++ b/src/hierarchies/updaters/nnw_updater.h @@ -15,9 +15,6 @@ class NNWUpdater ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, AbstractPriorModel &prior) override; - - // void compute_posterior_hypers(AbstractLikelihood& like, - // AbstractPriorModel& prior) override; }; #endif // BAYESMIX_HIERARCHIES_UPDATERS_NNW_UPDATER_H_ diff --git a/src/hierarchies/updaters/nnxig_updater.cc b/src/hierarchies/updaters/nnxig_updater.cc index b950f3e7c..84f91c73c 100644 --- a/src/hierarchies/updaters/nnxig_updater.cc +++ b/src/hierarchies/updaters/nnxig_updater.cc @@ -39,34 +39,3 @@ AbstractUpdater::ProtoHypersPtr NNxIGUpdater::compute_posterior_hypers( out.mutable_nnxig_state()->set_scale(scale); return std::make_shared(out); } - -// void NNxIGUpdater::compute_posterior_hypers(AbstractLikelihood& like, -// AbstractPriorModel& prior) { -// // Likelihood and Prior downcast -// auto& likecast = downcast_likelihood(like); -// auto& priorcast = downcast_prior(prior); - -// // Getting required quantities from likelihood and prior -// auto state = likecast.get_state(); -// int card = likecast.get_card(); -// double data_sum = likecast.get_data_sum(); -// double data_sum_squares = likecast.get_data_sum_squares(); -// auto hypers = priorcast.get_hypers(); - -// // No update possible -// if (card == 0) { -// priorcast.set_posterior_hypers(hypers); -// } - -// // Compute posterior hyperparameters -// Hyperparams::NxIG post_params; -// double var_y = data_sum_squares - 2 * state.mean * data_sum + -// card * state.mean * state.mean; -// post_params.mean = (hypers.var * data_sum + state.var * hypers.mean) / -// (card * hypers.var + state.var); -// post_params.var = (state.var * hypers.var) / (card * hypers.var + -// state.var); post_params.shape = hypers.shape + 0.5 * card; -// post_params.scale = hypers.scale + 0.5 * var_y; -// priorcast.set_posterior_hypers(post_params); -// return; -// }; diff --git a/src/hierarchies/updaters/nnxig_updater.h b/src/hierarchies/updaters/nnxig_updater.h index a595e544b..52d8f0a45 100644 --- a/src/hierarchies/updaters/nnxig_updater.h +++ b/src/hierarchies/updaters/nnxig_updater.h @@ -13,9 +13,6 @@ class NNxIGUpdater ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood &like, AbstractPriorModel &prior) override; - - // void compute_posterior_hypers(AbstractLikelihood& like, - // AbstractPriorModel& prior) override; }; #endif // BAYESMIX_HIERARCHIES_UPDATERS_NNXIG_UPDATER_H_ diff --git a/src/proto/algorithm_state.proto b/src/proto/algorithm_state.proto index 8975320dc..221e0ed4e 100644 --- a/src/proto/algorithm_state.proto +++ b/src/proto/algorithm_state.proto @@ -43,7 +43,6 @@ message AlgorithmState { NWDistribution nnw_state = 3; MultiNormalIGDistribution lin_reg_uni_state = 4; NxIGDistribution nnxig_state = 5; - // LapNIGState lapnig_state = 6; FAPriorDistribution fa_state = 7; } } diff --git a/src/proto/hierarchy_id.proto b/src/proto/hierarchy_id.proto index dcd870592..a6817aa5b 100644 --- a/src/proto/hierarchy_id.proto +++ b/src/proto/hierarchy_id.proto @@ -10,7 +10,7 @@ enum HierarchyId { NNIG = 1; // Normal - Normal Inverse Gamma NNW = 2; // Normal - Normal Wishart LinRegUni = 3; // Linear Regression (univariate response) - NNxIG = 4; // Normal - Normal x Inverse Gamma - LapNIG = 5; // Laplace - Normal Inverse Gamma - FA = 6; // Factor Analysers + LapNIG = 4; // Laplace - Normal Inverse Gamma + FA = 5; // Factor Analysers + NNxIG = 6; // Normal - Normal x Inverse Gamma } diff --git a/src/utils/distributions.cc b/src/utils/distributions.cc index d4840defa..c4536c1a8 100644 --- a/src/utils/distributions.cc +++ b/src/utils/distributions.cc @@ -1,6 +1,5 @@ #include "distributions.h" -// #include #include #include #include diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 8d147234c..20f7806ae 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -22,7 +22,6 @@ add_executable(test_bayesmix $ prior_models.cc hierarchies.cc lpdf.cc - # priors.cc // OLD, USEREI prior_models.cc eigen_utils.cc distributions.cc semi_hdp.cc diff --git a/test/distributions.cc b/test/distributions.cc index d5217513e..3705a5833 100644 --- a/test/distributions.cc +++ b/test/distributions.cc @@ -89,7 +89,6 @@ TEST(student_t, optimized) { Eigen::VectorXd x = Eigen::VectorXd::Ones(5); double lpdf_stan = stan::math::multi_student_t_lpdf(x, df, mean, sigma); - // std::cout << "lpdf_stan: " << lpdf_stan << std::endl; Eigen::MatrixXd sigma_inv = stan::math::inverse_spd(sigma); Eigen::MatrixXd sigma_inv_chol = @@ -100,8 +99,6 @@ TEST(student_t, optimized) { double our_lpdf = bayesmix::multi_student_t_invscale_lpdf( x, df, mean, sigma_inv_chol, logdet); - // std::cout << "our_lpdf: " << our_lpdf << std::endl; - ASSERT_LE(std::abs(our_lpdf - lpdf_stan), 0.001); } diff --git a/test/priors.cc b/test/priors.cc deleted file mode 100644 index 2b5893f27..000000000 --- a/test/priors.cc +++ /dev/null @@ -1,161 +0,0 @@ -#include -#include - -#include - -#include "algorithm_state.pb.h" -#include "src/hierarchies/nnig_hierarchy.h" -#include "src/hierarchies/nnw_hierarchy.h" -#include "src/hierarchies/nnxig_hierarchy.h" -#include "src/mixings/dirichlet_mixing.h" -#include "src/utils/proto_utils.h" - -TEST(mixing, fixed_value) { - DirichletMixing mix; - bayesmix::DPPrior prior; - double m = 2.0; - prior.mutable_fixed_value()->set_totalmass(m); - double m_state = prior.fixed_value().totalmass(); - ASSERT_DOUBLE_EQ(m, m_state); - mix.get_mutable_prior()->CopyFrom(prior); - mix.initialize(); - double m_mix = mix.get_state().totalmass; - ASSERT_DOUBLE_EQ(m, m_mix); - - std::vector> hiers(100); - unsigned int n_data = 1000; - mix.update_state(hiers, std::vector(n_data)); - double m_mix_after = mix.get_state().totalmass; - ASSERT_DOUBLE_EQ(m, m_mix_after); -} - -TEST(mixing, gamma_prior) { - DirichletMixing mix; - bayesmix::DPPrior prior; - double alpha = 1.0; - double beta = 2.0; - double m_prior = alpha / beta; - prior.mutable_gamma_prior()->mutable_totalmass_prior()->set_shape(alpha); - prior.mutable_gamma_prior()->mutable_totalmass_prior()->set_rate(beta); - mix.get_mutable_prior()->CopyFrom(prior); - mix.initialize(); - double m_mix = mix.get_state().totalmass; - ASSERT_DOUBLE_EQ(m_prior, m_mix); - - std::vector> hiers(100); - unsigned int n_data = 1000; - mix.update_state(hiers, std::vector(n_data)); - double m_mix_after = mix.get_state().totalmass; - - std::cout << " after = " << m_mix_after << std::endl; - ASSERT_TRUE(m_mix_after > m_mix); -} - -TEST(hierarchies, fixed_values) { - bayesmix::NNIGPrior prior; - bayesmix::AlgorithmState::HierarchyHypers prior_out; - prior.mutable_fixed_values()->set_mean(5.0); - prior.mutable_fixed_values()->set_var_scaling(0.1); - prior.mutable_fixed_values()->set_shape(2.0); - prior.mutable_fixed_values()->set_scale(2.0); - - auto hier = std::make_shared(); - hier->get_mutable_prior()->CopyFrom(prior); - hier->initialize(); - - std::vector> unique_values; - std::vector states; - - // Check equality before update - unique_values.push_back(hier); - for (size_t i = 1; i < 4; i++) { - unique_values.push_back(hier->clone()); - unique_values[i]->write_hypers_to_proto(&prior_out); - ASSERT_EQ(prior.fixed_values().DebugString(), - prior_out.nnig_state().DebugString()); - } - - // Check equality after update - unique_values[0]->update_hypers(states); - unique_values[0]->write_hypers_to_proto(&prior_out); - for (size_t i = 1; i < 4; i++) { - unique_values[i]->write_hypers_to_proto(&prior_out); - ASSERT_EQ(prior.fixed_values().DebugString(), - prior_out.nnig_state().DebugString()); - } -} - -TEST(hierarchies, normal_mean_prior) { - bayesmix::NNWPrior prior; - bayesmix::AlgorithmState::HierarchyHypers prior_out; - Eigen::Vector2d mu00; - mu00 << 0.0, 0.0; - auto ident = Eigen::Matrix2d::Identity(); - - prior.mutable_normal_mean_prior()->set_var_scaling(0.1); - bayesmix::to_proto( - mu00, - prior.mutable_normal_mean_prior()->mutable_mean_prior()->mutable_mean()); - bayesmix::to_proto( - ident, - prior.mutable_normal_mean_prior()->mutable_mean_prior()->mutable_var()); - prior.mutable_normal_mean_prior()->set_deg_free(2.0); - bayesmix::to_proto(ident, - prior.mutable_normal_mean_prior()->mutable_scale()); - - std::vector states(4); - for (int i = 0; i < states.size(); i++) { - double mean = 9.0 + i; - Eigen::Vector2d vec; - vec << mean, mean; - bayesmix::to_proto(vec, - states[i].mutable_multi_ls_state()->mutable_mean()); - bayesmix::to_proto(ident, - states[i].mutable_multi_ls_state()->mutable_prec()); - } - - NNWHierarchy hier; - hier.get_mutable_prior()->CopyFrom(prior); - hier.initialize(); - - hier.update_hypers(states); - hier.write_hypers_to_proto(&prior_out); - Eigen::Vector2d mean_out = bayesmix::to_eigen(prior_out.nnw_state().mean()); - std::cout << " after = " << mean_out(0) << " " << mean_out(1) - << std::endl; - assert(mu00(0) < mean_out(0) && mu00(1) < mean_out(1)); -} - -TEST(hierarchies, nxig_fixed_values) { - bayesmix::NNxIGPrior prior; - bayesmix::AlgorithmState::HierarchyHypers prior_out; - prior.mutable_fixed_values()->set_mean(5.0); - prior.mutable_fixed_values()->set_var(1.0); - prior.mutable_fixed_values()->set_shape(2.0); - prior.mutable_fixed_values()->set_scale(2.0); - - auto hier = std::make_shared(); - hier->get_mutable_prior()->CopyFrom(prior); - hier->initialize(); - - std::vector> unique_values; - std::vector states; - - // Check equality before update - unique_values.push_back(hier); - for (size_t i = 1; i < 4; i++) { - unique_values.push_back(hier->clone()); - unique_values[i]->write_hypers_to_proto(&prior_out); - ASSERT_EQ(prior.fixed_values().DebugString(), - prior_out.nnxig_state().DebugString()); - } - - // Check equality after update - unique_values[0]->update_hypers(states); - unique_values[0]->write_hypers_to_proto(&prior_out); - for (size_t i = 1; i < 4; i++) { - unique_values[i]->write_hypers_to_proto(&prior_out); - ASSERT_EQ(prior.fixed_values().DebugString(), - prior_out.nnxig_state().DebugString()); - } -} From 3d43592a7f31392b483e7f111e1e96914551b766 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 23:17:42 +0200 Subject: [PATCH 234/317] Remove test_mh_updater --- CMakeLists.txt | 8 +++---- test_mh_updater.cpp | 54 --------------------------------------------- 2 files changed, 4 insertions(+), 58 deletions(-) delete mode 100644 test_mh_updater.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5965e53df..09cbcfb7b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,7 +212,7 @@ if (NOT DISABLE_EXAMPLES) endif() # Test MH updater -add_executable(test_mh $ test_mh_updater.cpp) -target_include_directories(test_mh PUBLIC ${INCLUDE_PATHS}) -target_link_libraries(test_mh PUBLIC ${LINK_LIBRARIES}) -target_compile_options(test_mh PUBLIC ${COMPILE_OPTIONS}) +# add_executable(test_mh $ test_mh_updater.cpp) +# target_include_directories(test_mh PUBLIC ${INCLUDE_PATHS}) +# target_link_libraries(test_mh PUBLIC ${LINK_LIBRARIES}) +# target_compile_options(test_mh PUBLIC ${COMPILE_OPTIONS}) diff --git a/test_mh_updater.cpp b/test_mh_updater.cpp deleted file mode 100644 index c568e7c9b..000000000 --- a/test_mh_updater.cpp +++ /dev/null @@ -1,54 +0,0 @@ -#include - -#include -#include - -#include "lib/argparse/argparse.h" -#include "src/includes.h" - -int main() { - // Define prior hypers - bayesmix::AlgorithmState::HierarchyHypers hypers_proto; - hypers_proto.mutable_nnig_state()->set_mean(0.0); - hypers_proto.mutable_nnig_state()->set_var_scaling(0.1); - hypers_proto.mutable_nnig_state()->set_shape(4.0); - hypers_proto.mutable_nnig_state()->set_scale(3.0); - - bayesmix::NNIGPrior hier_prior; - hier_prior.mutable_fixed_values()->set_mean(0.0); - hier_prior.mutable_fixed_values()->set_var_scaling(0.1); - hier_prior.mutable_fixed_values()->set_shape(4.0); - hier_prior.mutable_fixed_values()->set_scale(3.0); - - auto prior = std::make_shared(); - prior->get_mutable_prior()->CopyFrom(hier_prior); - - // prior->set_hypers_from_proto(hypers_proto); - auto like = std::make_shared(); - auto updater = std::make_shared(0.001); - auto hier = std::make_shared(); - hier->set_likelihood(like); - hier->set_prior(prior); - hier->set_updater(updater); - std::cout << "here" << std::endl; - - hier->initialize(); - std::cout << "initializing" << std::endl; - - auto& rng = bayesmix::Rng::Instance().get(); - int ndata = 250; - Eigen::VectorXd data(ndata); - for (int i = 0; i < ndata; i++) { - data(i) = stan::math::normal_rng(5, 1.0, rng); - hier->add_datum(i, data.row(i)); - } - - int niter = 10000; - Eigen::MatrixXd chain(niter, 2); - for (int i = 0; i < niter; i++) { - hier->sample_full_cond(); - chain.row(i) = hier->get_state().get_unconstrained(); - } - - bayesmix::write_matrix_to_file(chain, "mcmc_chain_test.csv"); -} From 232fd08313420c9decbb78fffe16a9f211e65b7f Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 23:19:23 +0200 Subject: [PATCH 235/317] implemented to_proto() method --- .../likelihoods/states/multi_ls_state.h | 12 ++++++++--- .../likelihoods/states/uni_lin_reg_ls_state.h | 20 ++++++++++++------- .../likelihoods/states/uni_ls_state.h | 13 +++++++++--- src/hierarchies/priors/mnig_prior_model.cc | 16 +++++---------- src/hierarchies/priors/nig_prior_model.cc | 13 +++++------- src/hierarchies/priors/nw_prior_model.cc | 9 +-------- src/hierarchies/priors/nxig_prior_model.cc | 11 ++++------ 7 files changed, 47 insertions(+), 47 deletions(-) diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h index eb97fc5ee..4211003b0 100644 --- a/src/hierarchies/likelihoods/states/multi_ls_state.h +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -51,6 +51,8 @@ class MultiLS { Eigen::MatrixXd prec, prec_chol; double prec_logdet; + using ProtoState = bayesmix::AlgorithmState::ClusterState; + Eigen::VectorXd get_unconstrained() { return multi_ls_to_unconstrained(mean, prec); } @@ -68,7 +70,7 @@ class MultiLS { prec_logdet = 2 * log(diag.array()).sum(); } - void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { + void set_from_proto(const ProtoState &state_) { mean = to_eigen(state_.multi_ls_state().mean()); prec = to_eigen(state_.multi_ls_state().prec()); prec_chol = to_eigen(state_.multi_ls_state().prec_chol()); @@ -76,8 +78,8 @@ class MultiLS { prec_logdet = 2 * log(diag.array()).sum(); } - bayesmix::AlgorithmState::ClusterState get_as_proto() { - bayesmix::AlgorithmState::ClusterState state; + ProtoState get_as_proto() { + ProtoState state; bayesmix::to_proto(mean, state.mutable_multi_ls_state()->mutable_mean()); bayesmix::to_proto(prec, state.mutable_multi_ls_state()->mutable_prec()); bayesmix::to_proto(prec_chol, @@ -85,6 +87,10 @@ class MultiLS { return state; } + std::shared_ptr to_proto() { + return std::make_shared(get_as_proto()); + } + double log_det_jac() { return multi_ls_log_det_jac(prec); } }; diff --git a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h index fc7b015e9..1cbe36a5f 100644 --- a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h @@ -41,6 +41,8 @@ class UniLinRegLS { Eigen::VectorXd regression_coeffs; double var; + using ProtoState = bayesmix::AlgorithmState::ClusterState; + Eigen::VectorXd get_unconstrained() { Eigen::VectorXd temp(regression_coeffs.size() + 1); temp << regression_coeffs, var; @@ -54,21 +56,25 @@ class UniLinRegLS { var = temp(dim); } - void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { + void set_from_proto(const ProtoState &state_) { regression_coeffs = bayesmix::to_eigen(state_.lin_reg_uni_ls_state().regression_coeffs()); var = state_.lin_reg_uni_ls_state().var(); } - bayesmix::AlgorithmState::ClusterState get_as_proto() { - bayesmix::LinRegUniLSState out; - bayesmix::to_proto(regression_coeffs, out.mutable_regression_coeffs()); - out.set_var(var); - bayesmix::AlgorithmState::ClusterState state; - state.mutable_lin_reg_uni_ls_state()->CopyFrom(out); + ProtoState get_as_proto() { + ProtoState state; + bayesmix::to_proto( + regression_coeffs, + state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()); + state.mutable_lin_reg_uni_ls_state()->set_var(var); return state; } + std::shared_ptr to_proto() { + return std::make_shared(get_as_proto()); + } + double log_det_jac() { Eigen::VectorXd temp(regression_coeffs.size() + 1); temp << regression_coeffs, var; diff --git a/src/hierarchies/likelihoods/states/uni_ls_state.h b/src/hierarchies/likelihoods/states/uni_ls_state.h index a347c54cd..a37be6280 100644 --- a/src/hierarchies/likelihoods/states/uni_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_ls_state.h @@ -1,6 +1,7 @@ #ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LS_STATE_H_ #define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_UNI_LS_STATE_H_ +#include #include #include "algorithm_state.pb.h" @@ -35,6 +36,8 @@ class UniLS { public: double mean, var; + using ProtoState = bayesmix::AlgorithmState::ClusterState; + Eigen::VectorXd get_unconstrained() { Eigen::VectorXd temp(2); temp << mean, var; @@ -47,18 +50,22 @@ class UniLS { var = temp(1); } - void set_from_proto(const bayesmix::AlgorithmState::ClusterState &state_) { + void set_from_proto(const ProtoState &state_) { mean = state_.uni_ls_state().mean(); var = state_.uni_ls_state().var(); } - bayesmix::AlgorithmState::ClusterState get_as_proto() { - bayesmix::AlgorithmState::ClusterState state; + ProtoState get_as_proto() { + ProtoState state; state.mutable_uni_ls_state()->set_mean(mean); state.mutable_uni_ls_state()->set_var(var); return state; } + std::shared_ptr to_proto() { + return std::make_shared(get_as_proto()); + } + double log_det_jac() { Eigen::VectorXd temp(2); temp << mean, var; diff --git a/src/hierarchies/priors/mnig_prior_model.cc b/src/hierarchies/priors/mnig_prior_model.cc index 215647780..f65732080 100644 --- a/src/hierarchies/priors/mnig_prior_model.cc +++ b/src/hierarchies/priors/mnig_prior_model.cc @@ -19,17 +19,11 @@ std::shared_ptr MNIGPriorModel::sample( Eigen::VectorXd mean = bayesmix::to_eigen(params.mean()); Eigen::MatrixXd var_scaling = bayesmix::to_eigen(params.var_scaling()); - double var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); - Eigen::VectorXd regression_coeffs = - stan::math::multi_normal_prec_rng(mean, var_scaling / var, rng); - - bayesmix::AlgorithmState::ClusterState state; - bayesmix::to_proto( - regression_coeffs, - state.mutable_lin_reg_uni_ls_state()->mutable_regression_coeffs()); - state.mutable_lin_reg_uni_ls_state()->set_var(var); - - return std::make_shared(state); + State::UniLinRegLS out; + out.var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); + out.regression_coeffs = + stan::math::multi_normal_prec_rng(mean, var_scaling / out.var, rng); + return out.to_proto(); } void MNIGPriorModel::update_hypers( diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index 248372f49..31b5ea7be 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -90,14 +90,11 @@ std::shared_ptr NIGPriorModel::sample( auto params = (hier_hypers) ? hier_hypers->nnig_state() : get_hypers_proto()->nnig_state(); - double var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); - double mean = stan::math::normal_rng(params.mean(), - sqrt(var / params.var_scaling()), rng); - - bayesmix::AlgorithmState::ClusterState state; - state.mutable_uni_ls_state()->set_mean(mean); - state.mutable_uni_ls_state()->set_var(var); - return std::make_shared(state); + State::UniLS out; + out.var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); + out.mean = stan::math::normal_rng(params.mean(), + sqrt(out.var / params.var_scaling()), rng); + return out.to_proto(); } void NIGPriorModel::update_hypers( diff --git a/src/hierarchies/priors/nw_prior_model.cc b/src/hierarchies/priors/nw_prior_model.cc index 50fd2c0a0..342d2de4e 100644 --- a/src/hierarchies/priors/nw_prior_model.cc +++ b/src/hierarchies/priors/nw_prior_model.cc @@ -137,14 +137,7 @@ std::shared_ptr NWPriorModel::sample( out.mean = stan::math::multi_normal_prec_rng( mean, tau_new * params.var_scaling(), rng); write_prec_to_state(tau_new, &out); - - // Make output state - bayesmix::AlgorithmState::ClusterState state; - bayesmix::to_proto(out.mean, state.mutable_multi_ls_state()->mutable_mean()); - bayesmix::to_proto(out.prec, state.mutable_multi_ls_state()->mutable_prec()); - bayesmix::to_proto(out.prec_chol, - state.mutable_multi_ls_state()->mutable_prec_chol()); - return std::make_shared(state); + return out.to_proto(); }; void NWPriorModel::update_hypers( diff --git a/src/hierarchies/priors/nxig_prior_model.cc b/src/hierarchies/priors/nxig_prior_model.cc index 52dacc703..dc13f70d1 100644 --- a/src/hierarchies/priors/nxig_prior_model.cc +++ b/src/hierarchies/priors/nxig_prior_model.cc @@ -36,13 +36,10 @@ std::shared_ptr NxIGPriorModel::sample( auto params = (hier_hypers) ? hier_hypers->nnxig_state() : get_hypers_proto()->nnxig_state(); - double var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); - double mean = stan::math::normal_rng(params.mean(), sqrt(params.var()), rng); - - bayesmix::AlgorithmState::ClusterState state; - state.mutable_uni_ls_state()->set_mean(mean); - state.mutable_uni_ls_state()->set_var(var); - return std::make_shared(state); + State::UniLS out; + out.mean = stan::math::normal_rng(params.mean(), sqrt(params.var()), rng); + out.var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); + return out.to_proto(); }; void NxIGPriorModel::update_hypers( From 8c5ba27739f88826412d230be643d2e889aed83a Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 23:19:58 +0200 Subject: [PATCH 236/317] trivial changes --- src/hierarchies/likelihoods/states/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/likelihoods/states/CMakeLists.txt b/src/hierarchies/likelihoods/states/CMakeLists.txt index 89da41e53..2ccfbf343 100644 --- a/src/hierarchies/likelihoods/states/CMakeLists.txt +++ b/src/hierarchies/likelihoods/states/CMakeLists.txt @@ -1,7 +1,7 @@ target_sources(bayesmix PUBLIC + includes.h uni_ls_state.h multi_ls_state.h uni_lin_reg_ls_state.h fa_state.h - includes.h ) From 32d6528485c01191c5b6256b55234816fcca5ae7 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 28 Mar 2022 23:20:31 +0200 Subject: [PATCH 237/317] Implemented covariates_getter callable class --- src/hierarchies/base_hierarchy.h | 87 +++++++++++-------- src/hierarchies/likelihoods/base_likelihood.h | 48 ++++++---- src/utils/CMakeLists.txt | 1 + src/utils/covariates_getter.h | 27 ++++++ 4 files changed, 108 insertions(+), 55 deletions(-) create mode 100644 src/utils/covariates_getter.h diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 2bd9bcda3..2c9ceaa47 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -11,6 +11,7 @@ #include "abstract_hierarchy.h" #include "algorithm_state.pb.h" #include "hierarchy_id.pb.h" +#include "src/utils/covariates_getter.h" #include "src/utils/rng.h" #include "updaters/target_lpdf_unconstrained.h" @@ -200,25 +201,31 @@ class BaseHierarchy : public AbstractHierarchy { const Eigen::MatrixXd &covariates /*= Eigen::MatrixXd(0, 0)*/) const override { Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->conditional_pred_lpdf( - data.row(i), covariates.row(i)); - } + covariates_getter cov_getter(covariates); + for (int i = 0; i < data.rows(); i++) { + lpdf(i) = static_cast(this)->conditional_pred_lpdf( + data.row(i), cov_getter(i)); } + + // if (covariates.cols() == 0) { + // // Pass null value as covariate + // for (int i = 0; i < data.rows(); i++) { + // lpdf(i) = static_cast(this)->conditional_pred_lpdf( + // data.row(i), Eigen::RowVectorXd(0)); + // } + // } else if (covariates.rows() == 1) { + // // Use unique covariate + // for (int i = 0; i < data.rows(); i++) { + // lpdf(i) = static_cast(this)->conditional_pred_lpdf( + // data.row(i), covariates.row(0)); + // } + // } else { + // // Use different covariates + // for (int i = 0; i < data.rows(); i++) { + // lpdf(i) = static_cast(this)->conditional_pred_lpdf( + // data.row(i), covariates.row(i)); + // } + // } return lpdf; } @@ -239,25 +246,33 @@ class BaseHierarchy : public AbstractHierarchy { const Eigen::MatrixXd &covariates = Eigen::MatrixXd(0, 0)) override { like->clear_data(); like->clear_summary_statistics(); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - static_cast(this)->add_datum(i, data.row(i), false, - covariates.row(i)); - } + + covariates_getter cov_getter(covariates); + for (int i = 0; i < data.rows(); i++) { + static_cast(this)->add_datum(i, data.row(i), false, + cov_getter(i)); } + + // if (covariates.cols() == 0) { + // // Pass null value as covariate + // for (int i = 0; i < data.rows(); i++) { + // static_cast(this)->add_datum(i, data.row(i), false, + // Eigen::RowVectorXd(0)); + // } + // } else if (covariates.rows() == 1) { + // // Use unique covariate + // for (int i = 0; i < data.rows(); i++) { + // static_cast(this)->add_datum(i, data.row(i), false, + // covariates.row(0)); + // } + // } else { + // // Use different covariates + // for (int i = 0; i < data.rows(); i++) { + // static_cast(this)->add_datum(i, data.row(i), false, + // covariates.row(i)); + // } + // } + static_cast(this)->sample_full_cond(true); }; diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 591da527a..cb357c55d 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -11,6 +11,7 @@ #include "abstract_likelihood.h" #include "algorithm_state.pb.h" #include "likelihood_internal.h" +#include "src/utils/covariates_getter.h" template class BaseLikelihood : public AbstractLikelihood { @@ -195,26 +196,35 @@ template Eigen::VectorXd BaseLikelihood::lpdf_grid( const Eigen::MatrixXd &data, const Eigen::MatrixXd &covariates) const { Eigen::VectorXd lpdf(data.rows()); - if (covariates.cols() == 0) { - // Pass null value as covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->lpdf( - data.row(i), Eigen::RowVectorXd(0)); - } - } else if (covariates.rows() == 1) { - // Use unique covariate - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->lpdf(data.row(i), - covariates.row(0)); - } - } else { - // Use different covariates - for (int i = 0; i < data.rows(); i++) { - lpdf(i) = static_cast(this)->lpdf(data.row(i), - covariates.row(i)); - } - } + covariates_getter cov_getter(covariates); + for (int i = 0; i < data.rows(); i++) + lpdf(i) = + static_cast(this)->lpdf(data.row(i), cov_getter(i)); + return lpdf; } +/* OLD STUFF THAT WILL BE REMOVED */ +// Eigen::VectorXd lpdf(data.rows()); + +// if (covariates.cols() == 0) { +// // Pass null value as covariate +// for (int i = 0; i < data.rows(); i++) { +// lpdf(i) = static_cast(this)->lpdf( +// data.row(i), Eigen::RowVectorXd(0)); +// } +// } else if (covariates.rows() == 1) { +// // Use unique covariate +// for (int i = 0; i < data.rows(); i++) { +// lpdf(i) = static_cast(this)->lpdf(data.row(i), +// covariates.row(0)); +// } +// } else { +// // Use different covariates +// for (int i = 0; i < data.rows(); i++) { +// lpdf(i) = static_cast(this)->lpdf(data.row(i), +// covariates.row(i)); +// } +// } + #endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_BASE_LIKELIHOOD_H_ diff --git a/src/utils/CMakeLists.txt b/src/utils/CMakeLists.txt index dfec7c2a9..c891c24cf 100644 --- a/src/utils/CMakeLists.txt +++ b/src/utils/CMakeLists.txt @@ -15,4 +15,5 @@ target_sources(bayesmix rng.h testing_utils.h testing_utils.cc + covariates_getter.h ) diff --git a/src/utils/covariates_getter.h b/src/utils/covariates_getter.h new file mode 100644 index 000000000..46adb9802 --- /dev/null +++ b/src/utils/covariates_getter.h @@ -0,0 +1,27 @@ +#ifndef BAYESMIX_SRC_UTILS_COVARIATES_GETTER_H +#define BAYESMIX_SRC_UTILS_COVARIATES_GETTER_H + +#include +// #include "src/hierarchies/likelihoods/abstract_likelihood.h" +// #include "src/hierarchies/priors/abstract_prior_model.h" + +class covariates_getter { + protected: + const Eigen::MatrixXd* covariates; + + public: + covariates_getter(const Eigen::MatrixXd& covariates_) + : covariates(&covariates_){}; + + Eigen::RowVectorXd operator()(const size_t& i) const { + if (covariates->cols() == 0) { + return Eigen::RowVectorXd(0); + } else if (covariates->rows() == 1) { + return covariates->row(0); + } else { + return covariates->row(i); + } + }; +}; + +#endif // BAYESMIX_SRC_UTILS_COVARIATES_GETTER_H From 87963a82970973cfa3a0abde7bf235dc9c66430e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 29 Mar 2022 17:44:40 +0200 Subject: [PATCH 238/317] Changed fake_prior field (NEEDS APPROVAL) --- src/proto/algorithm_state.proto | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/proto/algorithm_state.proto b/src/proto/algorithm_state.proto index 221e0ed4e..2db4cbc47 100644 --- a/src/proto/algorithm_state.proto +++ b/src/proto/algorithm_state.proto @@ -38,7 +38,8 @@ message AlgorithmState { message HierarchyHypers { // Current values of the Hyperparameters of the Hierarchy oneof val { - EmptyPrior fake_prior = 1; + Vector fake_prior = 1; + //EmptyPrior fake_prior = 1; NIGDistribution nnig_state = 2; NWDistribution nnw_state = 3; MultiNormalIGDistribution lin_reg_uni_state = 4; From 9596664ccdb07a8d96734f91d1b58e9c18c3e7e8 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 29 Mar 2022 17:47:09 +0200 Subject: [PATCH 239/317] Ignore .old folders --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 7134946e4..a4fdf64a5 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ cmake-build-debug/ # .old folders src/hierarchies/updaters/.old/ test/.old/ +examples/gamma_hierarchy/.old/ From f58f9d8031abd87c4f5e6d056787fe64e2dca85d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 29 Mar 2022 17:47:30 +0200 Subject: [PATCH 240/317] GammaGamma example is now working --- examples/CMakeLists.txt | 7 +- examples/gamma_hierarchy/gamma_gamma_hier.h | 140 ------------------ examples/gamma_hierarchy/gamma_likelihood.h | 86 +++++++++++ examples/gamma_hierarchy/gamma_prior_model.h | 113 ++++++++++++++ .../gamma_hierarchy/gammagamma_hierarchy.h | 47 ++++++ examples/gamma_hierarchy/gammagamma_updater.h | 49 ++++++ examples/gamma_hierarchy/run_gamma_gamma.cc | 2 +- 7 files changed, 301 insertions(+), 143 deletions(-) delete mode 100644 examples/gamma_hierarchy/gamma_gamma_hier.h create mode 100644 examples/gamma_hierarchy/gamma_likelihood.h create mode 100644 examples/gamma_hierarchy/gamma_prior_model.h create mode 100644 examples/gamma_hierarchy/gammagamma_hierarchy.h create mode 100644 examples/gamma_hierarchy/gammagamma_updater.h diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 733662a26..09934a9f0 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,9 +1,12 @@ cmake_minimum_required(VERSION 3.13.0) project(examples_bayesmix) -add_executable(run_gamma $ +add_executable(run_gamma $ gamma_hierarchy/run_gamma_gamma.cc - gamma_hierarchy/gamma_gamma_hier.h + gamma_hierarchy/gammagamma_hierarchy.h + gamma_hierarchy/gamma_likelihood.h + gamma_hierarchy/gamma_prior_model.h + gamma_hierarchy/gammagamma_updater.h ) target_include_directories(run_gamma PUBLIC ${INCLUDE_PATHS}) diff --git a/examples/gamma_hierarchy/gamma_gamma_hier.h b/examples/gamma_hierarchy/gamma_gamma_hier.h deleted file mode 100644 index 7e7931808..000000000 --- a/examples/gamma_hierarchy/gamma_gamma_hier.h +++ /dev/null @@ -1,140 +0,0 @@ -#ifndef BAYESMIX_HIERARCHIES_GAMMAGAMMA_HIERARCHY_H_ -#define BAYESMIX_HIERARCHIES_GAMMAGAMMA_HIERARCHY_H_ - -#include -#include -#include - -#include -#include -#include - -#include "hierarchy_prior.pb.h" - -namespace GammaGamma { -//! Custom container for State values -struct State { - double rate; -}; - -//! Custom container for Hyperparameters values -struct Hyperparams { - double shape, rate_alpha, rate_beta; -}; -}; // namespace GammaGamma - -class GammaGammaHierarchy - : public ConjugateHierarchy { - public: - GammaGammaHierarchy(const double shape, const double rate_alpha, - const double rate_beta) - : shape(shape), rate_alpha(rate_alpha), rate_beta(rate_beta) { - create_empty_prior(); - } - ~GammaGammaHierarchy() = default; - - double like_lpdf(const Eigen::RowVectorXd &datum) const override { - return stan::math::gamma_lpdf(datum(0), hypers->shape, state.rate); - } - - double marg_lpdf( - const GammaGamma::Hyperparams ¶ms, const Eigen::RowVectorXd &datum, - const Eigen::RowVectorXd &covariate = Eigen::RowVectorXd(0)) const { - throw std::runtime_error("marg_lpdf() not implemented"); - return 0; - } - - GammaGamma::State draw(const GammaGamma::Hyperparams ¶ms) { - return GammaGamma::State{stan::math::gamma_rng( - params.rate_alpha, params.rate_beta, bayesmix::Rng::Instance().get())}; - } - - void update_summary_statistics(const Eigen::RowVectorXd &datum, - const bool add) { - if (add) { - data_sum += datum(0); - ndata += 1; - } else { - data_sum -= datum(0); - ndata -= 1; - } - } - - //! Computes and return posterior hypers given data currently in this cluster - GammaGamma::Hyperparams compute_posterior_hypers() { - GammaGamma::Hyperparams out; - out.shape = hypers->shape; - out.rate_alpha = hypers->rate_alpha + hypers->shape * ndata; - out.rate_beta = hypers->rate_beta + data_sum; - return out; - } - - void initialize_state() override { - state.rate = hypers->rate_alpha / hypers->rate_beta; - } - - void initialize_hypers() { - hypers->shape = shape; - hypers->rate_alpha = rate_alpha; - hypers->rate_beta = rate_beta; - } - - //! Removes every data point from this cluster - void clear_summary_statistics() { - data_sum = 0; - ndata = 0; - } - - bool is_multivariate() const override { return false; } - - void set_state_from_proto(const google::protobuf::Message &state_) override { - auto &statecast = google::protobuf::internal::down_cast< - const bayesmix::AlgorithmState::ClusterState &>(state_); - state.rate = statecast.general_state().data()[0]; - set_card(statecast.cardinality()); - } - - std::shared_ptr get_state_proto() - const override { - bayesmix::Vector state_; - state_.mutable_data()->Add(state.rate); - - auto out = std::make_unique(); - out->mutable_general_state()->CopyFrom(state_); - return out; - } - - void update_hypers(const std::vector - &states) override { - return; - } - - void write_hypers_to_proto( - google::protobuf::Message *const out) const override { - return; - } - - void set_hypers_from_proto( - const google::protobuf::Message &state_) override { - return; - } - - std::shared_ptr get_hypers_proto() - const override { - return nullptr; - } - - bayesmix::HierarchyId get_id() const override { - return bayesmix::HierarchyId::UNKNOWN_HIERARCHY; - } - - protected: - double data_sum = 0; - int ndata = 0; - - double shape, rate_alpha, rate_beta; -}; - -#endif // BAYESMIX_HIERARCHIES_GAMMAGAMMA_HIERARCHY_H_ diff --git a/examples/gamma_hierarchy/gamma_likelihood.h b/examples/gamma_hierarchy/gamma_likelihood.h new file mode 100644 index 000000000..3f8aac15a --- /dev/null +++ b/examples/gamma_hierarchy/gamma_likelihood.h @@ -0,0 +1,86 @@ +#ifndef BAYESMIX_HIERARCHIES_GAMMA_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_GAMMA_LIKELIHOOD_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "src/hierarchies/likelihoods/base_likelihood.h" + +namespace State { +class Gamma { + public: + double shape, rate; +}; +} // namespace State + +class GammaLikelihood : public BaseLikelihood { + public: + GammaLikelihood() = default; + ~GammaLikelihood() = default; + bool is_multivariate() const override { return false; }; + bool is_dependent() const override { return false; }; + void set_state_from_proto(const google::protobuf::Message &state_, + bool update_card = true) override; + void clear_summary_statistics() override; + + // Getters and Setters + int get_ndata() const { return ndata; }; + double get_shape() const { return state.shape; }; + double get_data_sum() const { return data_sum; }; + std::shared_ptr get_state_proto() + const override; + + protected: + double compute_lpdf(const Eigen::RowVectorXd &datum) const override; + void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; + + //! Sum of data in the cluster + double data_sum = 0; + //! number of data in the cluster + int ndata = 0; +}; + +/* DEFINITIONS */ +void GammaLikelihood::set_state_from_proto( + const google::protobuf::Message &state_, bool update_card) { + auto &statecast = downcast_state(state_); + state.rate = statecast.general_state().data()[0]; + if (update_card) set_card(statecast.cardinality()); +} + +void GammaLikelihood::clear_summary_statistics() { + data_sum = 0; + ndata = 0; +} + +std::shared_ptr +GammaLikelihood::get_state_proto() const { + bayesmix::Vector state_; + state_.mutable_data()->Add(state.shape); + state_.mutable_data()->Add(state.rate); + + auto out = std::make_shared(); + out->mutable_general_state()->CopyFrom(state_); + return out; +} + +double GammaLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { + return stan::math::gamma_lpdf(datum(0), state.shape, state.rate); +} + +void GammaLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, + bool add) { + if (add) { + data_sum += datum(0); + ndata += 1; + } else { + data_sum -= datum(0); + ndata -= 1; + } +} + +#endif // BAYESMIX_HIERARCHIES_GAMMA_LIKELIHOOD_H_ diff --git a/examples/gamma_hierarchy/gamma_prior_model.h b/examples/gamma_hierarchy/gamma_prior_model.h new file mode 100644 index 000000000..3c226df33 --- /dev/null +++ b/examples/gamma_hierarchy/gamma_prior_model.h @@ -0,0 +1,113 @@ +#ifndef BAYESMIX_HIERARCHIES_GAMMA_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_GAMMA_PRIOR_MODEL_H_ + +#include +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "hierarchy_prior.pb.h" +#include "src/hierarchies/priors/base_prior_model.h" +#include "src/utils/rng.h" + +namespace Hyperparams { +struct Gamma { + double rate_alpha, rate_beta; +}; +} // namespace Hyperparams + +class GammaPriorModel + : public BasePriorModel { + public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + + GammaPriorModel(double shape_ = -1, double rate_alpha_ = -1, + double rate_beta_ = -1); + ~GammaPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + std::shared_ptr sample( + ProtoHypersPtr hier_hypers = nullptr) override; + + void update_hypers(const std::vector + &states) override { + return; + }; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + ProtoHypersPtr get_hypers_proto() const override; + double get_shape() const { return shape; }; + + protected: + double shape, rate_alpha, rate_beta; + void initialize_hypers() override; +}; + +/* DEFINITIONS */ +GammaPriorModel::GammaPriorModel(double shape_, double rate_alpha_, + double rate_beta_) + : shape(shape_), rate_alpha(rate_alpha_), rate_beta(rate_beta_) { + create_empty_prior(); +}; + +double GammaPriorModel::lpdf(const google::protobuf::Message &state_) { + double rate = downcast_state(state_).general_state().data()[1]; + return stan::math::gamma_lpdf(rate, hypers->rate_alpha, hypers->rate_beta); +} + +std::shared_ptr GammaPriorModel::sample( + ProtoHypersPtr hier_hypers) { + auto &rng = bayesmix::Rng::Instance().get(); + + auto params = (hier_hypers) ? hier_hypers->fake_prior() + : get_hypers_proto()->fake_prior(); + double rate_alpha = params.data()[0]; + double rate_beta = params.data()[1]; + double new_rate = stan::math::gamma_rng(rate_alpha, rate_beta, rng); + + bayesmix::AlgorithmState::ClusterState out; + out.mutable_general_state()->mutable_data()->Add(shape); + out.mutable_general_state()->mutable_data()->Add(new_rate); + return std::make_shared(out); +} + +void GammaPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_); + hypers->rate_alpha = hyperscast.fake_prior().data()[0]; + hypers->rate_beta = hyperscast.fake_prior().data()[1]; +}; + +GammaPriorModel::ProtoHypersPtr GammaPriorModel::get_hypers_proto() const { + bayesmix::Vector hypers_; + hypers_.mutable_data()->Add(hypers->rate_alpha); + hypers_.mutable_data()->Add(hypers->rate_beta); + + ProtoHypersPtr out = std::make_shared(); + out->mutable_fake_prior()->CopyFrom(hypers_); + return out; +}; + +void GammaPriorModel::initialize_hypers() { + hypers->rate_alpha = rate_alpha; + hypers->rate_beta = rate_beta; + + // Checks + if (shape <= 0) { + throw std::runtime_error("shape must be positive"); + } + if (rate_alpha <= 0) { + throw std::runtime_error("rate_alpha must be positive"); + } + if (rate_beta <= 0) { + throw std::runtime_error("rate_beta must be positive"); + } +} + +#endif // BAYESMIX_HIERARCHIES_GAMMA_PRIOR_MODEL_H_ diff --git a/examples/gamma_hierarchy/gammagamma_hierarchy.h b/examples/gamma_hierarchy/gammagamma_hierarchy.h new file mode 100644 index 000000000..6fe135533 --- /dev/null +++ b/examples/gamma_hierarchy/gammagamma_hierarchy.h @@ -0,0 +1,47 @@ +#ifndef BAYESMIX_HIERARCHIES_GAMMA_GAMMA_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_GAMMA_GAMMA_HIERARCHY_H_ + +#include "gamma_likelihood.h" +#include "gamma_prior_model.h" +#include "gammagamma_updater.h" +#include "hierarchy_id.pb.h" +#include "src/hierarchies/base_hierarchy.h" + +class GammaGammaHierarchy + : public BaseHierarchy { + public: + GammaGammaHierarchy(double shape_, double rate_alpha_, double rate_beta_) { + auto prior = + std::make_shared(shape_, rate_alpha_, rate_beta_); + set_prior(prior); + }; + ~GammaGammaHierarchy() = default; + + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::UNKNOWN_HIERARCHY; + } + + void set_default_updater() { + updater = std::make_shared(); + } + + void initialize_state() override { + // Get hypers + auto hypers = prior->get_hypers(); + // Initialize likelihood state + State::Gamma state; + state.shape = prior->get_shape(); + state.rate = hypers.rate_alpha / hypers.rate_beta; + like->set_state(state); + }; + + double marg_lpdf(ProtoHypersPtr hier_params, + const Eigen::RowVectorXd &datum) const override { + throw( + std::runtime_error("marg_lpdf() not implemented for this hierarchy")); + return 0; + } +}; + +#endif // BAYESMIX_HIERARCHIES_GAMMA_GAMMA_HIERARCHY_H_ diff --git a/examples/gamma_hierarchy/gammagamma_updater.h b/examples/gamma_hierarchy/gammagamma_updater.h new file mode 100644 index 000000000..310ac65da --- /dev/null +++ b/examples/gamma_hierarchy/gammagamma_updater.h @@ -0,0 +1,49 @@ +#ifndef BAYESMIX_HIERARCHIES_GAMMA_GAMMA_UPDATER_H_ +#define BAYESMIX_HIERARCHIES_GAMMA_GAMMA_UPDATER_H_ + +#include "gamma_likelihood.h" +#include "gamma_prior_model.h" +#include "src/hierarchies/updaters/semi_conjugate_updater.h" + +class GammaGammaUpdater + : public SemiConjugateUpdater { + public: + GammaGammaUpdater() = default; + ~GammaGammaUpdater() = default; + + bool is_conjugate() const override { return true; }; + + ProtoHypersPtr compute_posterior_hypers(AbstractLikelihood& like, + AbstractPriorModel& prior) override; +}; + +/* DEFINITIONS */ +AbstractUpdater::ProtoHypersPtr GammaGammaUpdater::compute_posterior_hypers( + AbstractLikelihood& like, AbstractPriorModel& prior) { + // Likelihood and Prior downcast + auto& likecast = downcast_likelihood(like); + auto& priorcast = downcast_prior(prior); + + // Getting required quantities from likelihood and prior + int card = likecast.get_card(); + double data_sum = likecast.get_data_sum(); + double ndata = likecast.get_ndata(); + double shape = priorcast.get_shape(); + auto hypers = priorcast.get_hypers(); + + // No update possible + if (card == 0) { + return priorcast.get_hypers_proto(); + } + // Compute posterior hyperparameters + double rate_alpha_new = hypers.rate_alpha + shape * ndata; + double rate_beta_new = hypers.rate_beta + data_sum; + + // Proto conversion + ProtoHypers out; + out.mutable_fake_prior()->mutable_data()->Add(rate_alpha_new); + out.mutable_fake_prior()->mutable_data()->Add(rate_beta_new); + return std::make_shared(out); +} + +#endif // BAYESMIX_HIERARCHIES_GAMMA_GAMMA_UPDATER_H_ diff --git a/examples/gamma_hierarchy/run_gamma_gamma.cc b/examples/gamma_hierarchy/run_gamma_gamma.cc index f6502b72d..d8e5cf0f5 100644 --- a/examples/gamma_hierarchy/run_gamma_gamma.cc +++ b/examples/gamma_hierarchy/run_gamma_gamma.cc @@ -1,6 +1,6 @@ #include -#include "gamma_gamma_hier.h" +#include "gammagamma_hierarchy.h" #include "src/includes.h" Eigen::MatrixXd simulate_data(const unsigned int ndata) { From a4296d18bdfc5095194407cc8d3f15a9fe58421d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 29 Mar 2022 18:15:18 +0200 Subject: [PATCH 241/317] Cleaned code --- src/hierarchies/likelihoods/states/multi_ls_state.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h index 4211003b0..1ef26f192 100644 --- a/src/hierarchies/likelihoods/states/multi_ls_state.h +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -35,7 +35,6 @@ multi_ls_to_constrained(Eigen::Matrix in) { return std::make_tuple(mean, prec); } -// SEE GitHub for tests template T multi_ls_log_det_jac( Eigen::Matrix prec_constrained) { From 977d0171b71b8beea76c8e6576086fa835a3a9bb Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 29 Mar 2022 18:19:41 +0200 Subject: [PATCH 242/317] Comments on SFINAE exception handling --- src/hierarchies/likelihoods/likelihood_internal.h | 8 ++++++++ src/hierarchies/priors/prior_model_internal.h | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/src/hierarchies/likelihoods/likelihood_internal.h b/src/hierarchies/likelihoods/likelihood_internal.h index 28f0be6ed..364a43ca9 100644 --- a/src/hierarchies/likelihoods/likelihood_internal.h +++ b/src/hierarchies/likelihoods/likelihood_internal.h @@ -1,6 +1,14 @@ #ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_LIKELIHOOD_INTERNAL_H_ #define BAYESMIX_HIERARCHIES_LIKELIHOODS_LIKELIHOOD_INTERNAL_H_ +//! These functions exploit SFINAE to manage exception handling in all methods +//! required only if end user wants to rely on Metropolis-like updaters. SFINAE +//! (Substitution Failure Is Not An Error) is a C++ rule that applies during +//! overload resolution of function templates: When substituting the explicitly +//! specified or deduced type for the template parameter fails, the +//! specialization is discarded from the overload set instead of causing a +//! compile error. This feature is used in template metaprogramming. + namespace internal { /* SFINAE for cluster_lpdf_from_unconstrained() */ diff --git a/src/hierarchies/priors/prior_model_internal.h b/src/hierarchies/priors/prior_model_internal.h index 5272038ea..e1de1f5ff 100644 --- a/src/hierarchies/priors/prior_model_internal.h +++ b/src/hierarchies/priors/prior_model_internal.h @@ -1,6 +1,14 @@ #ifndef BAYESMIX_HIERARCHIES_PRIORS_PRIOR_MODEL_INTERNAL_H_ #define BAYESMIX_HIERARCHIES_PRIORS_PRIOR_MODEL_INTERNAL_H_ +//! These functions exploit SFINAE to manage exception handling in all methods +//! required only if end user wants to rely on Metropolis-like updaters. SFINAE +//! (Substitution Failure Is Not An Error) is a C++ rule that applies during +//! overload resolution of function templates: When substituting the explicitly +//! specified or deduced type for the template parameter fails, the +//! specialization is discarded from the overload set instead of causing a +//! compile error. This feature is used in template metaprogramming. + namespace internal { template From ad2a9cc489c0a7c233a98c0b10e4c6ce74c8f563 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 11 Apr 2022 16:47:41 +0200 Subject: [PATCH 243/317] Code cleaned --- src/hierarchies/base_hierarchy.h | 41 ------------------- src/hierarchies/likelihoods/base_likelihood.h | 23 ----------- 2 files changed, 64 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 2c9ceaa47..4e9591a5a 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -206,26 +206,6 @@ class BaseHierarchy : public AbstractHierarchy { lpdf(i) = static_cast(this)->conditional_pred_lpdf( data.row(i), cov_getter(i)); } - - // if (covariates.cols() == 0) { - // // Pass null value as covariate - // for (int i = 0; i < data.rows(); i++) { - // lpdf(i) = static_cast(this)->conditional_pred_lpdf( - // data.row(i), Eigen::RowVectorXd(0)); - // } - // } else if (covariates.rows() == 1) { - // // Use unique covariate - // for (int i = 0; i < data.rows(); i++) { - // lpdf(i) = static_cast(this)->conditional_pred_lpdf( - // data.row(i), covariates.row(0)); - // } - // } else { - // // Use different covariates - // for (int i = 0; i < data.rows(); i++) { - // lpdf(i) = static_cast(this)->conditional_pred_lpdf( - // data.row(i), covariates.row(i)); - // } - // } return lpdf; } @@ -252,27 +232,6 @@ class BaseHierarchy : public AbstractHierarchy { static_cast(this)->add_datum(i, data.row(i), false, cov_getter(i)); } - - // if (covariates.cols() == 0) { - // // Pass null value as covariate - // for (int i = 0; i < data.rows(); i++) { - // static_cast(this)->add_datum(i, data.row(i), false, - // Eigen::RowVectorXd(0)); - // } - // } else if (covariates.rows() == 1) { - // // Use unique covariate - // for (int i = 0; i < data.rows(); i++) { - // static_cast(this)->add_datum(i, data.row(i), false, - // covariates.row(0)); - // } - // } else { - // // Use different covariates - // for (int i = 0; i < data.rows(); i++) { - // static_cast(this)->add_datum(i, data.row(i), false, - // covariates.row(i)); - // } - // } - static_cast(this)->sample_full_cond(true); }; diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index cb357c55d..1c67f6ee7 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -204,27 +204,4 @@ Eigen::VectorXd BaseLikelihood::lpdf_grid( return lpdf; } -/* OLD STUFF THAT WILL BE REMOVED */ -// Eigen::VectorXd lpdf(data.rows()); - -// if (covariates.cols() == 0) { -// // Pass null value as covariate -// for (int i = 0; i < data.rows(); i++) { -// lpdf(i) = static_cast(this)->lpdf( -// data.row(i), Eigen::RowVectorXd(0)); -// } -// } else if (covariates.rows() == 1) { -// // Use unique covariate -// for (int i = 0; i < data.rows(); i++) { -// lpdf(i) = static_cast(this)->lpdf(data.row(i), -// covariates.row(0)); -// } -// } else { -// // Use different covariates -// for (int i = 0; i < data.rows(); i++) { -// lpdf(i) = static_cast(this)->lpdf(data.row(i), -// covariates.row(i)); -// } -// } - #endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_BASE_LIKELIHOOD_H_ From 480be303f1afa59146779300e8b2649e47cf58b7 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 11 Apr 2022 17:26:55 +0200 Subject: [PATCH 244/317] Add set_fa_hyperparams_from_data function --- src/hierarchies/fa_hierarchy.h | 42 ++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index 063be1d45..136761273 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -38,4 +38,46 @@ class FAHierarchy } }; +inline void set_fa_hyperparams_from_data(FAHierarchy* hier) { + auto dataset_ptr = + std::static_pointer_cast(hier->get_likelihood()) + ->get_dataset(); + auto hypers = + std::static_pointer_cast(hier->get_prior())->get_hypers(); + unsigned int dim = + std::static_pointer_cast(hier->get_likelihood()) + ->get_dim(); + + // Automatic initialization + if (dim == 0) { + hypers.mutilde = dataset_ptr->colwise().mean(); + dim = hypers.mutilde.size(); + } + if (hypers.beta.size() == 0) { + Eigen::MatrixXd centered = + dataset_ptr->rowwise() - dataset_ptr->colwise().mean(); + auto cov_llt = + ((centered.transpose() * centered) / double(dataset_ptr->rows() - 1.)) + .llt(); + Eigen::MatrixXd precision_matrix( + cov_llt.solve(Eigen::MatrixXd::Identity(dim, dim))); + hypers.beta = + (hypers.alpha0 - 1) * precision_matrix.diagonal().cwiseInverse(); + if (hypers.alpha0 == 1) { + throw std::invalid_argument( + "Scale parameter must be different than 1 when automatic " + "initialization is used"); + } + } + + bayesmix::AlgorithmState::HierarchyHypers state; + bayesmix::to_proto(hypers.mutilde, + state.mutable_fa_state()->mutable_mutilde()); + bayesmix::to_proto(hypers.beta, state.mutable_fa_state()->mutable_beta()); + state.mutable_fa_state()->set_alpha0(hypers.alpha0); + state.mutable_fa_state()->set_phi(hypers.phi); + state.mutable_fa_state()->set_q(hypers.q); + hier->get_prior()->set_hypers_from_proto(state); +}; + #endif // BAYESMIX_HIERARCHIES_FA_HIERARCHY_H_ From 375dfa2dd7060547899990cc6a483e1058490bf0 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 11 Apr 2022 17:27:22 +0200 Subject: [PATCH 245/317] Cleaned code --- src/hierarchies/priors/fa_prior_model.cc | 25 ------------------------ 1 file changed, 25 deletions(-) diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc index 99271094a..5f4b97a66 100644 --- a/src/hierarchies/priors/fa_prior_model.cc +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -131,28 +131,3 @@ void FAPriorModel::initialize_hypers() { throw std::invalid_argument("Unrecognized hierarchy prior"); } } - -// TODO -/* -// Automatic initialization -if (dim == 0) { - hypers->mutilde = dataset_ptr->colwise().mean(); - dim = hypers->mutilde.size(); -} -if (hypers->beta.size() == 0) { - Eigen::MatrixXd centered = - dataset_ptr->rowwise() - dataset_ptr->colwise().mean(); - auto cov_llt = ((centered.transpose() * centered) / - double(dataset_ptr->rows() - 1.)) - .llt(); - Eigen::MatrixXd precision_matrix( - cov_llt.solve(Eigen::MatrixXd::Identity(dim, dim))); - hypers->beta = - (hypers->alpha0 - 1) * precision_matrix.diagonal().cwiseInverse(); - if (hypers->alpha0 == 1) { - throw std::invalid_argument( - "Scale parameter must be different than 1 when automatic " - "initialization is used"); - } -} -*/ From 368ae87c7722d6dc7a521539d24a67e6f35076c1 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 11 Apr 2022 17:29:32 +0200 Subject: [PATCH 246/317] Trigger checks --- src/hierarchies/priors/fa_prior_model.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc index 5f4b97a66..22330ece3 100644 --- a/src/hierarchies/priors/fa_prior_model.cc +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -126,7 +126,7 @@ void FAPriorModel::initialize_hypers() { throw std::invalid_argument("Number of factors must be > 0"); } } - + // blabla else { throw std::invalid_argument("Unrecognized hierarchy prior"); } From 58ff956060f463faa868565e9553db5fac3c014f Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 11 Apr 2022 17:29:57 +0200 Subject: [PATCH 247/317] Revert changes --- src/hierarchies/priors/fa_prior_model.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc index 22330ece3..9bc3093cd 100644 --- a/src/hierarchies/priors/fa_prior_model.cc +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -125,9 +125,7 @@ void FAPriorModel::initialize_hypers() { if (hypers->q <= 0) { throw std::invalid_argument("Number of factors must be > 0"); } - } - // blabla - else { + } else { throw std::invalid_argument("Unrecognized hierarchy prior"); } } From f9b99bb5db1f867f077e0b15756a383d5e02aabf Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 11 Apr 2022 21:03:29 +0200 Subject: [PATCH 248/317] Back to Release mode --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 09cbcfb7b..a4d7cea61 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ project(bayesmix) set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Debug) +set(CMAKE_BUILD_TYPE Release) set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH}) set(CMAKE_CXX_FLAGS_RELEASE "-O3 -funroll-loops -ftree-vectorize") From f439ec2c50f1b3881d0dcc878c19378c4febecd4 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 14 Apr 2022 09:14:18 +0200 Subject: [PATCH 249/317] performance improvements on nnig hierarchy --- src/hierarchies/likelihoods/base_likelihood.h | 2 ++ .../likelihoods/states/base_state.h | 35 +++++++++++++++++++ .../likelihoods/uni_norm_likelihood.cc | 13 +++++++ .../likelihoods/uni_norm_likelihood.h | 3 +- src/hierarchies/priors/base_prior_model.h | 33 ++++++++++------- src/hierarchies/priors/nig_prior_model.cc | 6 ++-- src/hierarchies/priors/nig_prior_model.h | 12 +++---- .../updaters/semi_conjugate_updater.h | 4 +-- 8 files changed, 83 insertions(+), 25 deletions(-) create mode 100644 src/hierarchies/likelihoods/states/base_state.h diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 1c67f6ee7..60e1442c6 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -84,6 +84,8 @@ class BaseLikelihood : public AbstractLikelihood { //! Returns the class of the current state for the likelihood State get_state() const { return state; } + State* mutable_state() { return &state; } + //! Returns a vector storing the state in its unconstrained form Eigen::VectorXd get_unconstrained_state() override { return internal::get_unconstrained_state(state, 0); diff --git a/src/hierarchies/likelihoods/states/base_state.h b/src/hierarchies/likelihoods/states/base_state.h new file mode 100644 index 000000000..9e7a4fd8a --- /dev/null +++ b/src/hierarchies/likelihoods/states/base_state.h @@ -0,0 +1,35 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_BASE_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_BASE_STATE_H_ + +#include +#include + +#include "algorithm_state.pb.h" +#include "src/utils/proto_utils.h" + +namespace States { + +class BaseState { + public: + int card; + + using ProtoState = bayesmix::AlgorithmState::ClusterState; + + virtual Eigen::VectorXd get_unconstrained() { return Eigen::VectorXd(0); } + + virtual void set_from_unconstrained(Eigen::VectorXd in) { } + + virtual void set_from_proto(const ProtoState &state_, bool update_card) = 0; + + virtual ProtoState get_as_proto() = 0; + + std::shared_ptr to_proto() { + return std::make_shared(get_as_proto()); + } + + virtual double log_det_jac() { return -1; } +}; + +} // namespace States + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_BASE_STATE_H_ diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc index b0ceada1f..5ae5acfa8 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -23,6 +23,19 @@ void UniNormLikelihood::set_state_from_proto( if (update_card) set_card(statecast.cardinality()); } +void UniNormLikelihood::set_state( + const States::UniLS &state_, bool update_card) { + + int old_card; + if (! update_card) { + old_card = state.card; + } + state = state_; + if (! update_card) { + state.card = old_card; + } +} + std::shared_ptr UniNormLikelihood::get_state_proto() const { auto out = std::make_shared(); diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 02de25ba5..be4c954e3 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -12,7 +12,7 @@ #include "states/includes.h" class UniNormLikelihood - : public BaseLikelihood { + : public BaseLikelihood { public: UniNormLikelihood() = default; ~UniNormLikelihood() = default; @@ -20,6 +20,7 @@ class UniNormLikelihood bool is_dependent() const override { return false; }; void set_state_from_proto(const google::protobuf::Message &state_, bool update_card = true) override; + void set_state(const States::UniLS &state_, bool update_card = true); void clear_summary_statistics() override; double get_data_sum() const { return data_sum; }; double get_data_sum_squares() const { return data_sum_squares; }; diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index efdd70e30..a2c3e970f 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -15,7 +15,7 @@ #include "prior_model_internal.h" #include "src/utils/rng.h" -template +template class BasePriorModel : public AbstractPriorModel { public: //! Default constructor @@ -54,6 +54,13 @@ class BasePriorModel : public AbstractPriorModel { static_cast(*this), unconstrained_params, 0); }; + virtual State sample(ProtoHypersPtr hier_hypers = nullptr) = 0; + + std::shared_ptr sample_proto( + ProtoHypersPtr hier_hypers = nullptr) override { + return sample(hier_hypers).to_proto(); + } + //! Returns an independent, data-less copy of this object std::shared_ptr clone() const override; @@ -113,16 +120,16 @@ class BasePriorModel : public AbstractPriorModel { }; /* *** Methods Definitions *** */ -template +template std::shared_ptr -BasePriorModel::clone() const { +BasePriorModel::clone() const { auto out = std::make_shared(static_cast(*this)); return out; } -template +template std::shared_ptr -BasePriorModel::deep_clone() const { +BasePriorModel::deep_clone() const { auto out = std::make_shared(static_cast(*this)); // Prior Deep-clone @@ -142,17 +149,17 @@ BasePriorModel::deep_clone() const { return out; } -template +template google::protobuf::Message * -BasePriorModel::get_mutable_prior() { +BasePriorModel::get_mutable_prior() { if (prior == nullptr) { create_empty_prior(); } return prior.get(); } -template -void BasePriorModel::write_hypers_to_proto( +template +void BasePriorModel::write_hypers_to_proto( google::protobuf::Message *out) const { std::shared_ptr hypers_ = get_hypers_proto(); @@ -160,15 +167,15 @@ void BasePriorModel::write_hypers_to_proto( out_cast->CopyFrom(*hypers_.get()); } -template -void BasePriorModel::initialize() { +template +void BasePriorModel::initialize() { check_prior_is_set(); create_empty_hypers(); initialize_hypers(); } -template -void BasePriorModel::check_prior_is_set() const { +template +void BasePriorModel::check_prior_is_set() const { if (prior == nullptr) { throw std::invalid_argument("Hierarchy prior was not provided"); } diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index 31b5ea7be..acb11d123 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -84,17 +84,17 @@ double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { return target; } -std::shared_ptr NIGPriorModel::sample( +States::UniLS NIGPriorModel::sample( ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); auto params = (hier_hypers) ? hier_hypers->nnig_state() : get_hypers_proto()->nnig_state(); - State::UniLS out; + States::UniLS out; out.var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); out.mean = stan::math::normal_rng(params.mean(), sqrt(out.var / params.var_scaling()), rng); - return out.to_proto(); + return out; } void NIGPriorModel::update_hypers( diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index 2c4357955..00d66986d 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -11,8 +11,9 @@ #include "hyperparams.h" #include "src/utils/rng.h" -class NIGPriorModel : public BasePriorModel { +class NIGPriorModel + : public BasePriorModel { public: using AbstractPriorModel::ProtoHypers; using AbstractPriorModel::ProtoHypersPtr; @@ -26,8 +27,8 @@ class NIGPriorModel : public BasePriorModel &unconstrained_params) const { Eigen::Matrix constrained_params = - State::uni_ls_to_constrained(unconstrained_params); - T log_det_jac = State::uni_ls_log_det_jac(constrained_params); + States::uni_ls_to_constrained(unconstrained_params); + T log_det_jac = States::uni_ls_log_det_jac(constrained_params); T mean = constrained_params(0); T var = constrained_params(1); T lpdf = stan::math::normal_lpdf(mean, hypers->mean, @@ -37,8 +38,7 @@ class NIGPriorModel : public BasePriorModel sample( - ProtoHypersPtr hier_hypers = nullptr) override; + States::UniLS sample(ProtoHypersPtr hier_hypers = nullptr) override; void update_hypers(const std::vector &states) override; diff --git a/src/hierarchies/updaters/semi_conjugate_updater.h b/src/hierarchies/updaters/semi_conjugate_updater.h index 18fa0a63f..6517025a0 100644 --- a/src/hierarchies/updaters/semi_conjugate_updater.h +++ b/src/hierarchies/updaters/semi_conjugate_updater.h @@ -47,10 +47,10 @@ void SemiConjugateUpdater::draw( // Sample from the full conditional of a semi-conjugate hierarchy bool set_card = true; /*, use_post_hypers=true;*/ if (likecast.get_card() == 0) { - likecast.set_state_from_proto(*priorcast.sample(), !set_card); + likecast.set_state(priorcast.sample(), !set_card); } else { auto post_params = compute_posterior_hypers(likecast, priorcast); - likecast.set_state_from_proto(*priorcast.sample(post_params), !set_card); + likecast.set_state(priorcast.sample(post_params), !set_card); if (update_params) save_posterior_hypers(post_params); } } From 5e3b8ad495182dc4e4c4658725ce418b77e08fd1 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 14 Apr 2022 09:27:34 +0200 Subject: [PATCH 250/317] code cleanup --- src/hierarchies/base_hierarchy.h | 2 +- src/hierarchies/likelihoods/base_likelihood.h | 16 ++++++++++++-- .../likelihoods/uni_norm_likelihood.cc | 21 ------------------- .../likelihoods/uni_norm_likelihood.h | 3 --- 4 files changed, 15 insertions(+), 27 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 4e9591a5a..edf8a26d6 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -211,7 +211,7 @@ class BaseHierarchy : public AbstractHierarchy { //! Generates new state values from the centering prior distribution void sample_prior() override { - like->set_state_from_proto(*prior->sample(), false); + like->set_state(prior->sample(), false); }; //! Generates new state values from the centering posterior distribution diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index 60e1442c6..a6300842a 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -84,7 +84,7 @@ class BaseLikelihood : public AbstractLikelihood { //! Returns the class of the current state for the likelihood State get_state() const { return state; } - State* mutable_state() { return &state; } + State *mutable_state() { return &state; } //! Returns a vector storing the state in its unconstrained form Eigen::VectorXd get_unconstrained_state() override { @@ -92,7 +92,19 @@ class BaseLikelihood : public AbstractLikelihood { } //! Updates the state of the likelihood with the object given as input - void set_state(const State &_state) { state = _state; }; + void set_state(const State & state_, bool update_card = true) { + state = state_; + if (update_card) { + set_card(state.card); + } + }; + + void set_state_from_proto(const google::protobuf::Message &state_, + bool update_card = true) override { + State new_state; + new_state.set_from_proto(downcast_state(state_), update_card); + set_state(new_state, update_card); + } //! Updates the state of the likelihood starting from its unconstrained form void set_state_from_unconstrained( diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc index 5ae5acfa8..68f1ec51d 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -15,27 +15,6 @@ void UniNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, } } -void UniNormLikelihood::set_state_from_proto( - const google::protobuf::Message &state_, bool update_card) { - auto &statecast = downcast_state(state_); - state.mean = statecast.uni_ls_state().mean(); - state.var = statecast.uni_ls_state().var(); - if (update_card) set_card(statecast.cardinality()); -} - -void UniNormLikelihood::set_state( - const States::UniLS &state_, bool update_card) { - - int old_card; - if (! update_card) { - old_card = state.card; - } - state = state_; - if (! update_card) { - state.card = old_card; - } -} - std::shared_ptr UniNormLikelihood::get_state_proto() const { auto out = std::make_shared(); diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index be4c954e3..4f299993c 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -18,9 +18,6 @@ class UniNormLikelihood ~UniNormLikelihood() = default; bool is_multivariate() const override { return false; }; bool is_dependent() const override { return false; }; - void set_state_from_proto(const google::protobuf::Message &state_, - bool update_card = true) override; - void set_state(const States::UniLS &state_, bool update_card = true); void clear_summary_statistics() override; double get_data_sum() const { return data_sum; }; double get_data_sum_squares() const { return data_sum_squares; }; From 44976108934aeca0e384dfaf8a24f69b8dfc1dd2 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 14 Apr 2022 09:32:40 +0200 Subject: [PATCH 251/317] reverted naming convention --- .../likelihoods/states/base_state.h | 8 ++++---- .../likelihoods/states/uni_ls_state.h | 20 +++++++++---------- .../likelihoods/uni_norm_likelihood.h | 2 +- src/hierarchies/priors/abstract_prior_model.h | 5 +---- src/hierarchies/priors/base_prior_model.h | 19 +++++++++--------- src/hierarchies/priors/nig_prior_model.cc | 5 ++--- src/hierarchies/priors/nig_prior_model.h | 8 ++++---- 7 files changed, 32 insertions(+), 35 deletions(-) diff --git a/src/hierarchies/likelihoods/states/base_state.h b/src/hierarchies/likelihoods/states/base_state.h index 9e7a4fd8a..f7883f058 100644 --- a/src/hierarchies/likelihoods/states/base_state.h +++ b/src/hierarchies/likelihoods/states/base_state.h @@ -7,17 +7,17 @@ #include "algorithm_state.pb.h" #include "src/utils/proto_utils.h" -namespace States { +namespace State { class BaseState { public: int card; - + using ProtoState = bayesmix::AlgorithmState::ClusterState; virtual Eigen::VectorXd get_unconstrained() { return Eigen::VectorXd(0); } - virtual void set_from_unconstrained(Eigen::VectorXd in) { } + virtual void set_from_unconstrained(Eigen::VectorXd in) {} virtual void set_from_proto(const ProtoState &state_, bool update_card) = 0; @@ -30,6 +30,6 @@ class BaseState { virtual double log_det_jac() { return -1; } }; -} // namespace States +} // namespace State #endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_BASE_STATE_H_ diff --git a/src/hierarchies/likelihoods/states/uni_ls_state.h b/src/hierarchies/likelihoods/states/uni_ls_state.h index a37be6280..1c254c11f 100644 --- a/src/hierarchies/likelihoods/states/uni_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_ls_state.h @@ -6,6 +6,7 @@ #include "algorithm_state.pb.h" #include "src/utils/proto_utils.h" +#include "base_state.h" namespace State { @@ -32,41 +33,40 @@ T uni_ls_log_det_jac(Eigen::Matrix constrained) { return out; } -class UniLS { +class UniLS: public BaseState { public: double mean, var; using ProtoState = bayesmix::AlgorithmState::ClusterState; - Eigen::VectorXd get_unconstrained() { + Eigen::VectorXd get_unconstrained() override { Eigen::VectorXd temp(2); temp << mean, var; return uni_ls_to_unconstrained(temp); } - void set_from_unconstrained(Eigen::VectorXd in) { + void set_from_unconstrained(Eigen::VectorXd in) override { Eigen::VectorXd temp = uni_ls_to_constrained(in); mean = temp(0); var = temp(1); } - void set_from_proto(const ProtoState &state_) { + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } mean = state_.uni_ls_state().mean(); var = state_.uni_ls_state().var(); } - ProtoState get_as_proto() { + ProtoState get_as_proto() override { ProtoState state; state.mutable_uni_ls_state()->set_mean(mean); state.mutable_uni_ls_state()->set_var(var); return state; } - std::shared_ptr to_proto() { - return std::make_shared(get_as_proto()); - } - - double log_det_jac() { + double log_det_jac() override { Eigen::VectorXd temp(2); temp << mean, var; return uni_ls_log_det_jac(temp); diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 4f299993c..d84cdaa89 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -12,7 +12,7 @@ #include "states/includes.h" class UniNormLikelihood - : public BaseLikelihood { + : public BaseLikelihood { public: UniNormLikelihood() = default; ~UniNormLikelihood() = default; diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 412c51ae3..c9ca2f01d 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -68,10 +68,7 @@ class AbstractPriorModel { //! parameters to use for the sampling. The default behaviour (i.e. //! `hier_hypers = nullptr`) uses prior hyperparameters //! @return A Protobuf message storing the state sampled from the prior model - // virtual std::shared_ptr sample( - // bool use_post_hypers) = 0; - - virtual std::shared_ptr sample( + virtual std::shared_ptr sample_proto( ProtoHypersPtr hier_hypers = nullptr) = 0; //! Updates hyperparameter values given a vector of cluster states diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index a2c3e970f..cda5b6f28 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -57,8 +57,8 @@ class BasePriorModel : public AbstractPriorModel { virtual State sample(ProtoHypersPtr hier_hypers = nullptr) = 0; std::shared_ptr sample_proto( - ProtoHypersPtr hier_hypers = nullptr) override { - return sample(hier_hypers).to_proto(); + ProtoHypersPtr hier_hypers = nullptr) override { + return sample(hier_hypers).to_proto(); } //! Returns an independent, data-less copy of this object @@ -120,14 +120,14 @@ class BasePriorModel : public AbstractPriorModel { }; /* *** Methods Definitions *** */ -template +template std::shared_ptr BasePriorModel::clone() const { auto out = std::make_shared(static_cast(*this)); return out; } -template +template std::shared_ptr BasePriorModel::deep_clone() const { auto out = std::make_shared(static_cast(*this)); @@ -149,7 +149,7 @@ BasePriorModel::deep_clone() const { return out; } -template +template google::protobuf::Message * BasePriorModel::get_mutable_prior() { if (prior == nullptr) { @@ -158,7 +158,7 @@ BasePriorModel::get_mutable_prior() { return prior.get(); } -template +template void BasePriorModel::write_hypers_to_proto( google::protobuf::Message *out) const { std::shared_ptr hypers_ = @@ -167,15 +167,16 @@ void BasePriorModel::write_hypers_to_proto( out_cast->CopyFrom(*hypers_.get()); } -template +template void BasePriorModel::initialize() { check_prior_is_set(); create_empty_hypers(); initialize_hypers(); } -template -void BasePriorModel::check_prior_is_set() const { +template +void BasePriorModel::check_prior_is_set() + const { if (prior == nullptr) { throw std::invalid_argument("Hierarchy prior was not provided"); } diff --git a/src/hierarchies/priors/nig_prior_model.cc b/src/hierarchies/priors/nig_prior_model.cc index acb11d123..41756e89c 100644 --- a/src/hierarchies/priors/nig_prior_model.cc +++ b/src/hierarchies/priors/nig_prior_model.cc @@ -84,13 +84,12 @@ double NIGPriorModel::lpdf(const google::protobuf::Message &state_) { return target; } -States::UniLS NIGPriorModel::sample( - ProtoHypersPtr hier_hypers) { +State::UniLS NIGPriorModel::sample(ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); auto params = (hier_hypers) ? hier_hypers->nnig_state() : get_hypers_proto()->nnig_state(); - States::UniLS out; + State::UniLS out; out.var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); out.mean = stan::math::normal_rng(params.mean(), sqrt(out.var / params.var_scaling()), rng); diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index 00d66986d..45b61189e 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -12,7 +12,7 @@ #include "src/utils/rng.h" class NIGPriorModel - : public BasePriorModel { public: using AbstractPriorModel::ProtoHypers; @@ -27,8 +27,8 @@ class NIGPriorModel T lpdf_from_unconstrained( const Eigen::Matrix &unconstrained_params) const { Eigen::Matrix constrained_params = - States::uni_ls_to_constrained(unconstrained_params); - T log_det_jac = States::uni_ls_log_det_jac(constrained_params); + State::uni_ls_to_constrained(unconstrained_params); + T log_det_jac = State::uni_ls_log_det_jac(constrained_params); T mean = constrained_params(0); T var = constrained_params(1); T lpdf = stan::math::normal_lpdf(mean, hypers->mean, @@ -38,7 +38,7 @@ class NIGPriorModel return lpdf + log_det_jac; } - States::UniLS sample(ProtoHypersPtr hier_hypers = nullptr) override; + State::UniLS sample(ProtoHypersPtr hier_hypers = nullptr) override; void update_hypers(const std::vector &states) override; From e5277253fc76aa61b362940f9f7434d405011701 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 18 Apr 2022 16:59:57 +0100 Subject: [PATCH 252/317] nxig working --- src/hierarchies/base_hierarchy.h | 7 +++---- src/hierarchies/likelihoods/base_likelihood.h | 14 ++++++-------- src/hierarchies/likelihoods/states/CMakeLists.txt | 7 ++++--- src/hierarchies/likelihoods/states/base_state.h | 2 +- src/hierarchies/likelihoods/states/uni_ls_state.h | 6 +++--- src/hierarchies/priors/nxig_prior_model.cc | 5 ++--- src/hierarchies/priors/nxig_prior_model.h | 8 ++++---- 7 files changed, 23 insertions(+), 26 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index edf8a26d6..3206c4967 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -210,9 +210,7 @@ class BaseHierarchy : public AbstractHierarchy { } //! Generates new state values from the centering prior distribution - void sample_prior() override { - like->set_state(prior->sample(), false); - }; + void sample_prior() override { like->set_state(prior->sample(), false); }; //! Generates new state values from the centering posterior distribution //! @param update_params Save posterior hypers after the computation? @@ -260,7 +258,8 @@ class BaseHierarchy : public AbstractHierarchy { //! AlgoritmState::ClusterState message by adding the appropriate type std::shared_ptr get_state_proto() const override { - return like->get_state_proto(); + return std::make_shared( + like->get_state().get_as_proto()); } //! Returns a pointer to the Protobuf message of the prior of this cluster diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index a6300842a..d7da79553 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -92,7 +92,7 @@ class BaseLikelihood : public AbstractLikelihood { } //! Updates the state of the likelihood with the object given as input - void set_state(const State & state_, bool update_card = true) { + void set_state(const State &state_, bool update_card = true) { state = state_; if (update_card) { set_card(state.card); @@ -100,10 +100,10 @@ class BaseLikelihood : public AbstractLikelihood { }; void set_state_from_proto(const google::protobuf::Message &state_, - bool update_card = true) override { - State new_state; - new_state.set_from_proto(downcast_state(state_), update_card); - set_state(new_state, update_card); + bool update_card = true) override { + State new_state; + new_state.set_from_proto(downcast_state(state_), update_card); + set_state(new_state, update_card); } //! Updates the state of the likelihood starting from its unconstrained form @@ -199,10 +199,8 @@ void BaseLikelihood::remove_datum( template void BaseLikelihood::write_state_to_proto( google::protobuf::Message *out) const { - std::shared_ptr state_ = - get_state_proto(); auto *out_cast = downcast_state(out); - out_cast->CopyFrom(*state_.get()); + out_cast->CopyFrom(state.get_as_proto()); out_cast->set_cardinality(card); } diff --git a/src/hierarchies/likelihoods/states/CMakeLists.txt b/src/hierarchies/likelihoods/states/CMakeLists.txt index 2ccfbf343..6ecc1eebe 100644 --- a/src/hierarchies/likelihoods/states/CMakeLists.txt +++ b/src/hierarchies/likelihoods/states/CMakeLists.txt @@ -1,7 +1,8 @@ target_sources(bayesmix PUBLIC includes.h + base_state.h uni_ls_state.h - multi_ls_state.h - uni_lin_reg_ls_state.h - fa_state.h + # multi_ls_state.h + # uni_lin_reg_ls_state.h + # fa_state.h ) diff --git a/src/hierarchies/likelihoods/states/base_state.h b/src/hierarchies/likelihoods/states/base_state.h index f7883f058..df7d992fd 100644 --- a/src/hierarchies/likelihoods/states/base_state.h +++ b/src/hierarchies/likelihoods/states/base_state.h @@ -21,7 +21,7 @@ class BaseState { virtual void set_from_proto(const ProtoState &state_, bool update_card) = 0; - virtual ProtoState get_as_proto() = 0; + virtual ProtoState get_as_proto() const = 0; std::shared_ptr to_proto() { return std::make_shared(get_as_proto()); diff --git a/src/hierarchies/likelihoods/states/uni_ls_state.h b/src/hierarchies/likelihoods/states/uni_ls_state.h index 1c254c11f..76a199663 100644 --- a/src/hierarchies/likelihoods/states/uni_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_ls_state.h @@ -5,8 +5,8 @@ #include #include "algorithm_state.pb.h" -#include "src/utils/proto_utils.h" #include "base_state.h" +#include "src/utils/proto_utils.h" namespace State { @@ -33,7 +33,7 @@ T uni_ls_log_det_jac(Eigen::Matrix constrained) { return out; } -class UniLS: public BaseState { +class UniLS : public BaseState { public: double mean, var; @@ -59,7 +59,7 @@ class UniLS: public BaseState { var = state_.uni_ls_state().var(); } - ProtoState get_as_proto() override { + ProtoState get_as_proto() const override { ProtoState state; state.mutable_uni_ls_state()->set_mean(mean); state.mutable_uni_ls_state()->set_var(var); diff --git a/src/hierarchies/priors/nxig_prior_model.cc b/src/hierarchies/priors/nxig_prior_model.cc index dc13f70d1..0b7b0cbea 100644 --- a/src/hierarchies/priors/nxig_prior_model.cc +++ b/src/hierarchies/priors/nxig_prior_model.cc @@ -30,8 +30,7 @@ double NxIGPriorModel::lpdf(const google::protobuf::Message &state_) { return target; } -std::shared_ptr NxIGPriorModel::sample( - ProtoHypersPtr hier_hypers) { +State::UniLS NxIGPriorModel::sample(ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); auto params = (hier_hypers) ? hier_hypers->nnxig_state() : get_hypers_proto()->nnxig_state(); @@ -39,7 +38,7 @@ std::shared_ptr NxIGPriorModel::sample( State::UniLS out; out.mean = stan::math::normal_rng(params.mean(), sqrt(params.var()), rng); out.var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); - return out.to_proto(); + return out; }; void NxIGPriorModel::update_hypers( diff --git a/src/hierarchies/priors/nxig_prior_model.h b/src/hierarchies/priors/nxig_prior_model.h index 86a517c58..8f5e3f036 100644 --- a/src/hierarchies/priors/nxig_prior_model.h +++ b/src/hierarchies/priors/nxig_prior_model.h @@ -11,8 +11,9 @@ #include "hyperparams.h" #include "src/utils/rng.h" -class NxIGPriorModel : public BasePriorModel { +class NxIGPriorModel + : public BasePriorModel { public: using AbstractPriorModel::ProtoHypers; using AbstractPriorModel::ProtoHypersPtr; @@ -22,8 +23,7 @@ class NxIGPriorModel : public BasePriorModel sample( - ProtoHypersPtr hier_hypers = nullptr) override; + State::UniLS sample(ProtoHypersPtr hier_hypers = nullptr) override; void update_hypers(const std::vector &states) override; From bf386ef3eb893b1717ba41c90629c0adb5a532d8 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 18 Apr 2022 17:13:58 +0100 Subject: [PATCH 253/317] mulit norm like working --- src/hierarchies/likelihoods/CMakeLists.txt | 12 +++++----- .../likelihoods/abstract_likelihood.h | 6 ----- .../likelihoods/multi_norm_likelihood.cc | 22 ------------------- .../likelihoods/multi_norm_likelihood.h | 5 ----- .../likelihoods/states/CMakeLists.txt | 2 +- .../likelihoods/states/multi_ls_state.h | 21 +++++++++--------- .../likelihoods/uni_norm_likelihood.cc | 8 ------- .../likelihoods/uni_norm_likelihood.h | 3 --- src/hierarchies/priors/nw_prior_model.cc | 5 ++--- src/hierarchies/priors/nw_prior_model.h | 8 +++---- 10 files changed, 24 insertions(+), 68 deletions(-) diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index df10e8674..4d1ae8718 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -6,12 +6,12 @@ target_sources(bayesmix PUBLIC uni_norm_likelihood.cc multi_norm_likelihood.h multi_norm_likelihood.cc - uni_lin_reg_likelihood.h - uni_lin_reg_likelihood.cc - laplace_likelihood.h - laplace_likelihood.cc - fa_likelihood.h - fa_likelihood.cc + # uni_lin_reg_likelihood.h + # uni_lin_reg_likelihood.cc + # laplace_likelihood.h + # laplace_likelihood.cc + # fa_likelihood.h + # fa_likelihood.cc ) add_subdirectory(states) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 8239d25c7..bebb3804b 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -117,12 +117,6 @@ class AbstractLikelihood { virtual Eigen::VectorXd get_unconstrained_state() = 0; protected: - //! Writes current state to a Protobuf message and return a shared_ptr - //! New hierarchies have to first modify the field 'oneof val' in the - //! AlgoritmState::ClusterState message by adding the appropriate type - virtual std::shared_ptr - get_state_proto() const = 0; - //! Evaluates the log-likelihood of data in a single point //! @param datum Point which is to be evaluated //! @return The evaluation of the lpdf diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.cc b/src/hierarchies/likelihoods/multi_norm_likelihood.cc index 8c12942aa..ae51f3d47 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.cc @@ -25,28 +25,6 @@ void MultiNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, } } -void MultiNormLikelihood::set_state_from_proto( - const google::protobuf::Message &state_, bool update_card) { - auto &statecast = downcast_state(state_); - state.mean = to_eigen(statecast.multi_ls_state().mean()); - state.prec = to_eigen(statecast.multi_ls_state().prec()); - state.prec_chol = to_eigen(statecast.multi_ls_state().prec_chol()); - Eigen::VectorXd diag = state.prec_chol.diagonal(); - state.prec_logdet = 2 * log(diag.array()).sum(); - if (update_card) set_card(statecast.cardinality()); -} - -std::shared_ptr -MultiNormLikelihood::get_state_proto() const { - bayesmix::MultiLSState state_; - bayesmix::to_proto(state.mean, state_.mutable_mean()); - bayesmix::to_proto(state.prec, state_.mutable_prec()); - bayesmix::to_proto(state.prec_chol, state_.mutable_prec_chol()); - auto out = std::make_shared(); - out->mutable_multi_ls_state()->CopyFrom(state_); - return out; -} - void MultiNormLikelihood::clear_summary_statistics() { data_sum = Eigen::VectorXd::Zero(dim); data_sum_squares = Eigen::MatrixXd::Zero(dim, dim); diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index 0ffc84cf3..1887c93f4 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -19,8 +19,6 @@ class MultiNormLikelihood ~MultiNormLikelihood() = default; bool is_multivariate() const override { return true; }; bool is_dependent() const override { return false; }; - void set_state_from_proto(const google::protobuf::Message &state_, - bool update_card = true) override; void clear_summary_statistics() override; void set_dim(unsigned int dim_) { @@ -31,9 +29,6 @@ class MultiNormLikelihood Eigen::VectorXd get_data_sum() const { return data_sum; }; Eigen::MatrixXd get_data_sum_squares() const { return data_sum_squares; }; - std::shared_ptr get_state_proto() - const override; - protected: double compute_lpdf(const Eigen::RowVectorXd &datum) const override; void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; diff --git a/src/hierarchies/likelihoods/states/CMakeLists.txt b/src/hierarchies/likelihoods/states/CMakeLists.txt index 6ecc1eebe..33ffae0ee 100644 --- a/src/hierarchies/likelihoods/states/CMakeLists.txt +++ b/src/hierarchies/likelihoods/states/CMakeLists.txt @@ -2,7 +2,7 @@ target_sources(bayesmix PUBLIC includes.h base_state.h uni_ls_state.h - # multi_ls_state.h + multi_ls_state.h # uni_lin_reg_ls_state.h # fa_state.h ) diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h index 1ef26f192..99230d8b8 100644 --- a/src/hierarchies/likelihoods/states/multi_ls_state.h +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -5,6 +5,7 @@ #include #include "algorithm_state.pb.h" +#include "base_state.h" #include "src/utils/proto_utils.h" namespace State { @@ -44,7 +45,7 @@ T multi_ls_log_det_jac( return out; } -class MultiLS { +class MultiLS : public BaseState { public: Eigen::VectorXd mean; Eigen::MatrixXd prec, prec_chol; @@ -52,11 +53,11 @@ class MultiLS { using ProtoState = bayesmix::AlgorithmState::ClusterState; - Eigen::VectorXd get_unconstrained() { + Eigen::VectorXd get_unconstrained() override { return multi_ls_to_unconstrained(mean, prec); } - void set_from_unconstrained(Eigen::VectorXd in) { + void set_from_unconstrained(Eigen::VectorXd in) override { std::tie(mean, prec) = multi_ls_to_constrained(in); set_from_constrained(mean, prec); } @@ -69,7 +70,11 @@ class MultiLS { prec_logdet = 2 * log(diag.array()).sum(); } - void set_from_proto(const ProtoState &state_) { + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } + mean = to_eigen(state_.multi_ls_state().mean()); prec = to_eigen(state_.multi_ls_state().prec()); prec_chol = to_eigen(state_.multi_ls_state().prec_chol()); @@ -77,7 +82,7 @@ class MultiLS { prec_logdet = 2 * log(diag.array()).sum(); } - ProtoState get_as_proto() { + ProtoState get_as_proto() const override { ProtoState state; bayesmix::to_proto(mean, state.mutable_multi_ls_state()->mutable_mean()); bayesmix::to_proto(prec, state.mutable_multi_ls_state()->mutable_prec()); @@ -86,11 +91,7 @@ class MultiLS { return state; } - std::shared_ptr to_proto() { - return std::make_shared(get_as_proto()); - } - - double log_det_jac() { return multi_ls_log_det_jac(prec); } + double log_det_jac() override { return multi_ls_log_det_jac(prec); } }; } // namespace State diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.cc b/src/hierarchies/likelihoods/uni_norm_likelihood.cc index 68f1ec51d..3b5cdf06e 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.cc @@ -15,14 +15,6 @@ void UniNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, } } -std::shared_ptr -UniNormLikelihood::get_state_proto() const { - auto out = std::make_shared(); - out->mutable_uni_ls_state()->set_mean(state.mean); - out->mutable_uni_ls_state()->set_var(state.var); - return out; -} - void UniNormLikelihood::clear_summary_statistics() { data_sum = 0; data_sum_squares = 0; diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index d84cdaa89..7aba9a959 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -34,9 +34,6 @@ class UniNormLikelihood return out; } - std::shared_ptr get_state_proto() - const override; - protected: double compute_lpdf(const Eigen::RowVectorXd &datum) const override; void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; diff --git a/src/hierarchies/priors/nw_prior_model.cc b/src/hierarchies/priors/nw_prior_model.cc index 342d2de4e..cfee03f4c 100644 --- a/src/hierarchies/priors/nw_prior_model.cc +++ b/src/hierarchies/priors/nw_prior_model.cc @@ -121,8 +121,7 @@ double NWPriorModel::lpdf(const google::protobuf::Message &state_) { return target; } -std::shared_ptr NWPriorModel::sample( - ProtoHypersPtr hier_hypers) { +State::MultiLS NWPriorModel::sample(ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); auto params = (hier_hypers) ? hier_hypers->nnw_state() : get_hypers_proto()->nnw_state(); @@ -137,7 +136,7 @@ std::shared_ptr NWPriorModel::sample( out.mean = stan::math::multi_normal_prec_rng( mean, tau_new * params.var_scaling(), rng); write_prec_to_state(tau_new, &out); - return out.to_proto(); + return out; }; void NWPriorModel::update_hypers( diff --git a/src/hierarchies/priors/nw_prior_model.h b/src/hierarchies/priors/nw_prior_model.h index e7c0f06dd..e72f0c99a 100644 --- a/src/hierarchies/priors/nw_prior_model.h +++ b/src/hierarchies/priors/nw_prior_model.h @@ -11,16 +11,16 @@ #include "hyperparams.h" #include "src/utils/rng.h" -class NWPriorModel : public BasePriorModel { +class NWPriorModel + : public BasePriorModel { public: NWPriorModel() = default; ~NWPriorModel() = default; double lpdf(const google::protobuf::Message &state_) override; - std::shared_ptr sample( - ProtoHypersPtr hier_hypers = nullptr) override; + State::MultiLS sample(ProtoHypersPtr hier_hypers = nullptr) override; void update_hypers(const std::vector &states) override; From 6741c5e8fc96737d2f5d95e1c06925b8abacec15 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 18 Apr 2022 17:22:27 +0100 Subject: [PATCH 254/317] uni lin reg hierarchy ok --- .../likelihoods/states/uni_lin_reg_ls_state.h | 20 +++++++++---------- .../likelihoods/uni_lin_reg_likelihood.cc | 20 ------------------- .../likelihoods/uni_lin_reg_likelihood.h | 5 ----- src/hierarchies/priors/mnig_prior_model.cc | 5 ++--- src/hierarchies/priors/mnig_prior_model.h | 8 ++++---- 5 files changed, 16 insertions(+), 42 deletions(-) diff --git a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h index 1cbe36a5f..ba0db6350 100644 --- a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h @@ -5,6 +5,7 @@ #include #include "algorithm_state.pb.h" +#include "base_state.h" #include "src/utils/eigen_utils.h" #include "src/utils/proto_utils.h" @@ -36,33 +37,36 @@ T uni_lin_reg_log_det_jac(Eigen::Matrix constrained) { return out; } -class UniLinRegLS { +class UniLinRegLS : public BaseState { public: Eigen::VectorXd regression_coeffs; double var; using ProtoState = bayesmix::AlgorithmState::ClusterState; - Eigen::VectorXd get_unconstrained() { + Eigen::VectorXd get_unconstrained() override { Eigen::VectorXd temp(regression_coeffs.size() + 1); temp << regression_coeffs, var; return uni_lin_reg_to_unconstrained(temp); } - void set_from_unconstrained(Eigen::VectorXd in) { + void set_from_unconstrained(Eigen::VectorXd in) override { Eigen::VectorXd temp = uni_lin_reg_to_constrained(in); int dim = in.size() - 1; regression_coeffs = temp.head(dim); var = temp(dim); } - void set_from_proto(const ProtoState &state_) { + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } regression_coeffs = bayesmix::to_eigen(state_.lin_reg_uni_ls_state().regression_coeffs()); var = state_.lin_reg_uni_ls_state().var(); } - ProtoState get_as_proto() { + ProtoState get_as_proto() const override { ProtoState state; bayesmix::to_proto( regression_coeffs, @@ -71,11 +75,7 @@ class UniLinRegLS { return state; } - std::shared_ptr to_proto() { - return std::make_shared(get_as_proto()); - } - - double log_det_jac() { + double log_det_jac() override { Eigen::VectorXd temp(regression_coeffs.size() + 1); temp << regression_coeffs, var; return uni_lin_reg_log_det_jac(temp); diff --git a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc index a4ddfaab9..550eec5fd 100644 --- a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc +++ b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.cc @@ -2,32 +2,12 @@ #include "src/utils/eigen_utils.h" -void UniLinRegLikelihood::set_state_from_proto( - const google::protobuf::Message &state_, bool update_card) { - auto &statecast = downcast_state(state_); - state.regression_coeffs = - bayesmix::to_eigen(statecast.lin_reg_uni_ls_state().regression_coeffs()); - state.var = statecast.lin_reg_uni_ls_state().var(); - if (update_card) set_card(statecast.cardinality()); -} - void UniLinRegLikelihood::clear_summary_statistics() { mixed_prod = Eigen::VectorXd::Zero(dim); data_sum_squares = 0.0; covar_sum_squares = Eigen::MatrixXd::Zero(dim, dim); } -std::shared_ptr -UniLinRegLikelihood::get_state_proto() const { - bayesmix::LinRegUniLSState state_; - bayesmix::to_proto(state.regression_coeffs, - state_.mutable_regression_coeffs()); - state_.set_var(state.var); - auto out = std::make_shared(); - out->mutable_lin_reg_uni_ls_state()->CopyFrom(state_); - return out; -} - double UniLinRegLikelihood::compute_lpdf( const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const { diff --git a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h index 25f559e66..4f8d7eaea 100644 --- a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h +++ b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h @@ -18,8 +18,6 @@ class UniLinRegLikelihood ~UniLinRegLikelihood() = default; bool is_multivariate() const override { return false; }; bool is_dependent() const override { return true; }; - void set_state_from_proto(const google::protobuf::Message &state_, - bool update_card = true) override; void clear_summary_statistics() override; // Getters and Setters @@ -32,9 +30,6 @@ class UniLinRegLikelihood Eigen::MatrixXd get_covar_sum_squares() const { return covar_sum_squares; }; Eigen::VectorXd get_mixed_prod() const { return mixed_prod; }; - std::shared_ptr get_state_proto() - const override; - protected: double compute_lpdf(const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const override; diff --git a/src/hierarchies/priors/mnig_prior_model.cc b/src/hierarchies/priors/mnig_prior_model.cc index f65732080..2ec02169f 100644 --- a/src/hierarchies/priors/mnig_prior_model.cc +++ b/src/hierarchies/priors/mnig_prior_model.cc @@ -11,8 +11,7 @@ double MNIGPriorModel::lpdf(const google::protobuf::Message &state_) { return target; } -std::shared_ptr MNIGPriorModel::sample( - ProtoHypersPtr hier_hypers) { +State::UniLinRegLS MNIGPriorModel::sample(ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); auto params = (hier_hypers) ? hier_hypers->lin_reg_uni_state() : get_hypers_proto()->lin_reg_uni_state(); @@ -23,7 +22,7 @@ std::shared_ptr MNIGPriorModel::sample( out.var = stan::math::inv_gamma_rng(params.shape(), params.scale(), rng); out.regression_coeffs = stan::math::multi_normal_prec_rng(mean, var_scaling / out.var, rng); - return out.to_proto(); + return out; } void MNIGPriorModel::update_hypers( diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h index f6949fce7..b1a5aff66 100644 --- a/src/hierarchies/priors/mnig_prior_model.h +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -10,8 +10,9 @@ #include "hyperparams.h" #include "src/utils/rng.h" -class MNIGPriorModel : public BasePriorModel { +class MNIGPriorModel + : public BasePriorModel { public: using AbstractPriorModel::ProtoHypers; using AbstractPriorModel::ProtoHypersPtr; @@ -21,8 +22,7 @@ class MNIGPriorModel : public BasePriorModel sample( - ProtoHypersPtr hier_hypers = nullptr) override; + State::UniLinRegLS sample(ProtoHypersPtr hier_hypers = nullptr) override; void update_hypers(const std::vector &states) override; From c49beccdd9c981a507331d6f31b728c852e71fda Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 18 Apr 2022 17:43:11 +0100 Subject: [PATCH 255/317] laplace and fa compiling --- src/hierarchies/fa_hierarchy.h | 3 +- src/hierarchies/likelihoods/CMakeLists.txt | 12 +++--- src/hierarchies/likelihoods/fa_likelihood.cc | 36 ------------------ src/hierarchies/likelihoods/fa_likelihood.h | 10 ----- .../likelihoods/laplace_likelihood.cc | 16 -------- .../likelihoods/laplace_likelihood.h | 5 --- .../likelihoods/states/CMakeLists.txt | 2 +- src/hierarchies/likelihoods/states/fa_state.h | 38 ++++++++++++++++++- src/hierarchies/priors/fa_prior_model.cc | 11 +----- src/hierarchies/priors/fa_prior_model.h | 6 +-- src/hierarchies/updaters/fa_updater.cc | 14 +------ 11 files changed, 52 insertions(+), 101 deletions(-) diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index 136761273..15d278023 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -32,9 +32,8 @@ class FAHierarchy state.psi = hypers.beta / (hypers.alpha0 + 1.); state.lambda = Eigen::MatrixXd::Zero(dim, hypers.q); state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); + state.compute_wood_factors(); like->set_state(state); - like->compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda, - state.psi_inverse); } }; diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index 4d1ae8718..df10e8674 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -6,12 +6,12 @@ target_sources(bayesmix PUBLIC uni_norm_likelihood.cc multi_norm_likelihood.h multi_norm_likelihood.cc - # uni_lin_reg_likelihood.h - # uni_lin_reg_likelihood.cc - # laplace_likelihood.h - # laplace_likelihood.cc - # fa_likelihood.h - # fa_likelihood.cc + uni_lin_reg_likelihood.h + uni_lin_reg_likelihood.cc + laplace_likelihood.h + laplace_likelihood.cc + fa_likelihood.h + fa_likelihood.cc ) add_subdirectory(states) diff --git a/src/hierarchies/likelihoods/fa_likelihood.cc b/src/hierarchies/likelihoods/fa_likelihood.cc index ae843598b..81973647d 100644 --- a/src/hierarchies/likelihoods/fa_likelihood.cc +++ b/src/hierarchies/likelihoods/fa_likelihood.cc @@ -2,36 +2,10 @@ #include "src/utils/distributions.h" -void FALikelihood::set_state_from_proto( - const google::protobuf::Message& state_, bool update_card) { - auto& statecast = downcast_state(state_); - state.mu = bayesmix::to_eigen(statecast.fa_state().mu()); - state.psi = bayesmix::to_eigen(statecast.fa_state().psi()); - state.eta = bayesmix::to_eigen(statecast.fa_state().eta()); - state.lambda = bayesmix::to_eigen(statecast.fa_state().lambda()); - state.psi_inverse = state.psi.cwiseInverse().asDiagonal(); - compute_wood_factors(state.cov_wood, state.cov_logdet, state.lambda, - state.psi_inverse); - if (update_card) set_card(statecast.cardinality()); -} - void FALikelihood::clear_summary_statistics() { data_sum = Eigen::VectorXd::Zero(dim); } -std::shared_ptr -FALikelihood::get_state_proto() const { - bayesmix::FAState state_; - bayesmix::to_proto(state.mu, state_.mutable_mu()); - bayesmix::to_proto(state.psi, state_.mutable_psi()); - bayesmix::to_proto(state.eta, state_.mutable_eta()); - bayesmix::to_proto(state.lambda, state_.mutable_lambda()); - - auto out = std::make_shared(); - out->mutable_fa_state()->CopyFrom(state_); - return out; -} - double FALikelihood::compute_lpdf(const Eigen::RowVectorXd& datum) const { return bayesmix::multi_normal_lpdf_woodbury_chol( datum, state.mu, state.psi_inverse, state.cov_wood, state.cov_logdet); @@ -45,13 +19,3 @@ void FALikelihood::update_sum_stats(const Eigen::RowVectorXd& datum, data_sum -= datum; } } - -void FALikelihood::compute_wood_factors( - Eigen::MatrixXd& cov_wood, double& cov_logdet, - const Eigen::MatrixXd& lambda, - const Eigen::DiagonalMatrix& psi_inverse) { - auto [cov_wood_, cov_logdet_] = - bayesmix::compute_wood_chol_and_logdet(psi_inverse, lambda); - cov_logdet = cov_logdet_; - cov_wood = cov_wood_; -} diff --git a/src/hierarchies/likelihoods/fa_likelihood.h b/src/hierarchies/likelihoods/fa_likelihood.h index 83a9c2753..926c022f9 100644 --- a/src/hierarchies/likelihoods/fa_likelihood.h +++ b/src/hierarchies/likelihoods/fa_likelihood.h @@ -18,8 +18,6 @@ class FALikelihood : public BaseLikelihood { ~FALikelihood() = default; bool is_multivariate() const override { return true; }; bool is_dependent() const override { return false; }; - void set_state_from_proto(const google::protobuf::Message& state_, - bool update_card = true) override; void clear_summary_statistics() override; void set_dim(unsigned int dim_) { dim = dim_; @@ -28,14 +26,6 @@ class FALikelihood : public BaseLikelihood { unsigned int get_dim() const { return dim; }; Eigen::VectorXd get_data_sum() const { return data_sum; }; - std::shared_ptr get_state_proto() - const override; - - void compute_wood_factors( - Eigen::MatrixXd& cov_wood, double& cov_logdet, - const Eigen::MatrixXd& lambda, - const Eigen::DiagonalMatrix& psi_inverse); - protected: double compute_lpdf(const Eigen::RowVectorXd& datum) const override; void update_sum_stats(const Eigen::RowVectorXd& datum, bool add) override; diff --git a/src/hierarchies/likelihoods/laplace_likelihood.cc b/src/hierarchies/likelihoods/laplace_likelihood.cc index bbd6c799c..7d99a7efe 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.cc +++ b/src/hierarchies/likelihoods/laplace_likelihood.cc @@ -4,19 +4,3 @@ double LaplaceLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { return stan::math::double_exponential_lpdf( datum(0), state.mean, stan::math::sqrt(state.var / 2.0)); } - -void LaplaceLikelihood::set_state_from_proto( - const google::protobuf::Message &state_, bool update_card) { - auto &statecast = downcast_state(state_); - state.mean = statecast.uni_ls_state().mean(); - state.var = statecast.uni_ls_state().var(); - if (update_card) set_card(statecast.cardinality()); -} - -std::shared_ptr -LaplaceLikelihood::get_state_proto() const { - auto out = std::make_shared(); - out->mutable_uni_ls_state()->set_mean(state.mean); - out->mutable_uni_ls_state()->set_var(state.var); - return out; -} diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h index 15e2124ed..26b9243a1 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.h +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -18,8 +18,6 @@ class LaplaceLikelihood ~LaplaceLikelihood() = default; bool is_multivariate() const override { return false; }; bool is_dependent() const override { return false; }; - void set_state_from_proto(const google::protobuf::Message &state_, - bool update_card = true) override; void clear_summary_statistics() override { return; }; template @@ -39,9 +37,6 @@ class LaplaceLikelihood return out; } - std::shared_ptr get_state_proto() - const override; - protected: double compute_lpdf(const Eigen::RowVectorXd &datum) const override; void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override { diff --git a/src/hierarchies/likelihoods/states/CMakeLists.txt b/src/hierarchies/likelihoods/states/CMakeLists.txt index 33ffae0ee..4f850a603 100644 --- a/src/hierarchies/likelihoods/states/CMakeLists.txt +++ b/src/hierarchies/likelihoods/states/CMakeLists.txt @@ -3,6 +3,6 @@ target_sources(bayesmix PUBLIC base_state.h uni_ls_state.h multi_ls_state.h - # uni_lin_reg_ls_state.h + uni_lin_reg_ls_state.h # fa_state.h ) diff --git a/src/hierarchies/likelihoods/states/fa_state.h b/src/hierarchies/likelihoods/states/fa_state.h index a876aaf1e..8f299859c 100644 --- a/src/hierarchies/likelihoods/states/fa_state.h +++ b/src/hierarchies/likelihoods/states/fa_state.h @@ -5,17 +5,53 @@ #include #include "algorithm_state.pb.h" +#include "base_state.h" +#include "src/utils/distributions.h" #include "src/utils/eigen_utils.h" #include "src/utils/proto_utils.h" namespace State { -class FA { +class FA : public BaseState { public: Eigen::VectorXd mu, psi; Eigen::MatrixXd eta, lambda, cov_wood; Eigen::DiagonalMatrix psi_inverse; double cov_logdet; + + using ProtoState = bayesmix::AlgorithmState::ClusterState; + + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } + + mu = bayesmix::to_eigen(state_.fa_state().mu()); + psi = bayesmix::to_eigen(state_.fa_state().psi()); + eta = bayesmix::to_eigen(state_.fa_state().eta()); + lambda = bayesmix::to_eigen(state_.fa_state().lambda()); + psi_inverse = psi.cwiseInverse().asDiagonal(); + compute_wood_factors(); + } + + void compute_wood_factors() { + auto [cov_wood_, cov_logdet_] = + bayesmix::compute_wood_chol_and_logdet(psi_inverse, lambda); + cov_logdet = cov_logdet_; + cov_wood = cov_wood_; + } + + ProtoState get_as_proto() const override { + bayesmix::FAState state_; + bayesmix::to_proto(mu, state_.mutable_mu()); + bayesmix::to_proto(psi, state_.mutable_psi()); + bayesmix::to_proto(eta, state_.mutable_eta()); + bayesmix::to_proto(lambda, state_.mutable_lambda()); + + bayesmix::AlgorithmState::ClusterState out; + out.mutable_fa_state()->CopyFrom(state_); + return out; + } }; } // namespace State diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc index 9bc3093cd..d75442795 100644 --- a/src/hierarchies/priors/fa_prior_model.cc +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -29,8 +29,7 @@ double FAPriorModel::lpdf(const google::protobuf::Message &state_) { return target; } -std::shared_ptr FAPriorModel::sample( - ProtoHypersPtr hier_hypers) { +State::FA FAPriorModel::sample(ProtoHypersPtr hier_hypers) { // Random seed auto &rng = bayesmix::Rng::Instance().get(); @@ -53,13 +52,7 @@ std::shared_ptr FAPriorModel::sample( out.lambda(j, i) = stan::math::normal_rng(0, 1, rng); } } - - // Eigen2Proto conversion - bayesmix::AlgorithmState::ClusterState state; - bayesmix::to_proto(out.mu, state.mutable_fa_state()->mutable_mu()); - bayesmix::to_proto(out.psi, state.mutable_fa_state()->mutable_psi()); - bayesmix::to_proto(out.lambda, state.mutable_fa_state()->mutable_lambda()); - return std::make_shared(state); + return out; } void FAPriorModel::update_hypers( diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index 66088d5f4..dd379a414 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -11,7 +11,8 @@ #include "src/utils/rng.h" class FAPriorModel - : public BasePriorModel { + : public BasePriorModel { public: using AbstractPriorModel::ProtoHypers; using AbstractPriorModel::ProtoHypersPtr; @@ -21,8 +22,7 @@ class FAPriorModel double lpdf(const google::protobuf::Message &state_) override; - std::shared_ptr sample( - ProtoHypersPtr hier_hypers = nullptr) override; + State::FA sample(ProtoHypersPtr hier_hypers = nullptr) override; void update_hypers(const std::vector &states) override; diff --git a/src/hierarchies/updaters/fa_updater.cc b/src/hierarchies/updaters/fa_updater.cc index e88208b44..24d3408a8 100644 --- a/src/hierarchies/updaters/fa_updater.cc +++ b/src/hierarchies/updaters/fa_updater.cc @@ -10,7 +10,7 @@ void FAUpdater::draw(AbstractLikelihood& like, AbstractPriorModel& prior, // Sample from the full conditional of the fa hierarchy bool set_card = true, use_post_hypers = true; if (likecast.get_card() == 0) { - likecast.set_state_from_proto(*priorcast.sample(), !set_card); + likecast.set_state(priorcast.sample(), !set_card); } else { // Get state and hypers State::FA new_state = likecast.get_state(); @@ -20,17 +20,7 @@ void FAUpdater::draw(AbstractLikelihood& like, AbstractPriorModel& prior, sample_mu(new_state, hypers, likecast); sample_psi(new_state, hypers, likecast); sample_lambda(new_state, hypers, likecast); - // Eigen2Proto conversion - bayesmix::AlgorithmState::ClusterState new_state_proto; - bayesmix::to_proto(new_state.eta, - new_state_proto.mutable_fa_state()->mutable_eta()); - bayesmix::to_proto(new_state.mu, - new_state_proto.mutable_fa_state()->mutable_mu()); - bayesmix::to_proto(new_state.psi, - new_state_proto.mutable_fa_state()->mutable_psi()); - bayesmix::to_proto(new_state.lambda, - new_state_proto.mutable_fa_state()->mutable_lambda()); - likecast.set_state_from_proto(new_state_proto, !set_card); + likecast.set_state(new_state, !set_card); } } From 54d64941d833bf916638bbd089e47812e9b7b151 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 18 Apr 2022 17:43:28 +0100 Subject: [PATCH 256/317] reverted changes --- src/utils/testing_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/testing_utils.cc b/src/utils/testing_utils.cc index 2de404271..8a5d8a465 100644 --- a/src/utils/testing_utils.cc +++ b/src/utils/testing_utils.cc @@ -57,7 +57,7 @@ std::shared_ptr get_algorithm(const std::string& id, int dim) { if (dim == 1) { hier = get_univariate_nnig_hierarchy(); } else { - hier = get_multivariate_nnw_hierarchy(dim); + // hier = get_multivariate_nnw_hierarchy(dim); } hier->initialize(); algo->set_mixing(mixing); From 57d843b18847c72281e6b88d0829d22bba4725ca Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 18 Apr 2022 17:52:36 +0100 Subject: [PATCH 257/317] test working --- test/prior_models.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/prior_models.cc b/test/prior_models.cc index 20bff0d6c..be039e4f8 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -125,7 +125,8 @@ TEST(nig_prior_model, sample) { auto state2 = prior->sample(); // Check if they coincides - ASSERT_TRUE(state1->DebugString() != state2->DebugString()); + ASSERT_TRUE(state1.get_as_proto().DebugString() != + state2.get_as_proto().DebugString()); } TEST(nxig_prior_model, set_get_hypers) { @@ -207,7 +208,8 @@ TEST(nxig_prior_model, sample) { auto state2 = prior->sample(); // Check if they coincides - ASSERT_TRUE(state1->DebugString() != state2->DebugString()); + ASSERT_TRUE(state1.get_as_proto().DebugString() != + state2.get_as_proto().DebugString()); } TEST(nig_prior_model, unconstrained_lpdf) { @@ -382,7 +384,8 @@ TEST(nw_prior_model, sample) { auto state2 = prior->sample(); // Check if they coincides - ASSERT_TRUE(state1->DebugString() != state2->DebugString()); + ASSERT_TRUE(state1.get_as_proto().DebugString() != + state2.get_as_proto().DebugString()); } TEST(mnig_prior_model, set_get_hypers) { @@ -474,5 +477,6 @@ TEST(mnig_prior_model, sample) { auto state2 = prior->sample(); // Check if they coincides - ASSERT_TRUE(state1->DebugString() != state2->DebugString()); + ASSERT_TRUE(state1.get_as_proto().DebugString() != + state2.get_as_proto().DebugString()); } From cc0351cecb056b374b7d8fab98da1ba311750f6a Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 18 Apr 2022 17:52:49 +0100 Subject: [PATCH 258/317] uncommented line --- src/utils/testing_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/testing_utils.cc b/src/utils/testing_utils.cc index 8a5d8a465..2de404271 100644 --- a/src/utils/testing_utils.cc +++ b/src/utils/testing_utils.cc @@ -57,7 +57,7 @@ std::shared_ptr get_algorithm(const std::string& id, int dim) { if (dim == 1) { hier = get_univariate_nnig_hierarchy(); } else { - // hier = get_multivariate_nnw_hierarchy(dim); + hier = get_multivariate_nnw_hierarchy(dim); } hier->initialize(); algo->set_mixing(mixing); From 4d6642b841c1589ff65f729c33aa288af187ac50 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 18 Apr 2022 19:15:21 +0100 Subject: [PATCH 259/317] const refs --- src/hierarchies/likelihoods/states/base_state.h | 4 ++-- src/hierarchies/likelihoods/states/multi_ls_state.h | 2 +- src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h | 2 +- src/hierarchies/likelihoods/states/uni_ls_state.h | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hierarchies/likelihoods/states/base_state.h b/src/hierarchies/likelihoods/states/base_state.h index df7d992fd..83d9fb7f0 100644 --- a/src/hierarchies/likelihoods/states/base_state.h +++ b/src/hierarchies/likelihoods/states/base_state.h @@ -17,13 +17,13 @@ class BaseState { virtual Eigen::VectorXd get_unconstrained() { return Eigen::VectorXd(0); } - virtual void set_from_unconstrained(Eigen::VectorXd in) {} + virtual void set_from_unconstrained(const Eigen::VectorXd &in) {} virtual void set_from_proto(const ProtoState &state_, bool update_card) = 0; virtual ProtoState get_as_proto() const = 0; - std::shared_ptr to_proto() { + std::shared_ptr to_proto() const { return std::make_shared(get_as_proto()); } diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h index 99230d8b8..dc7306e7a 100644 --- a/src/hierarchies/likelihoods/states/multi_ls_state.h +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -57,7 +57,7 @@ class MultiLS : public BaseState { return multi_ls_to_unconstrained(mean, prec); } - void set_from_unconstrained(Eigen::VectorXd in) override { + void set_from_unconstrained(const Eigen::VectorXd &in) override { std::tie(mean, prec) = multi_ls_to_constrained(in); set_from_constrained(mean, prec); } diff --git a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h index ba0db6350..c3a0115a1 100644 --- a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h @@ -50,7 +50,7 @@ class UniLinRegLS : public BaseState { return uni_lin_reg_to_unconstrained(temp); } - void set_from_unconstrained(Eigen::VectorXd in) override { + void set_from_unconstrained(const Eigen::VectorXd &in) override { Eigen::VectorXd temp = uni_lin_reg_to_constrained(in); int dim = in.size() - 1; regression_coeffs = temp.head(dim); diff --git a/src/hierarchies/likelihoods/states/uni_ls_state.h b/src/hierarchies/likelihoods/states/uni_ls_state.h index 76a199663..335fc6cf7 100644 --- a/src/hierarchies/likelihoods/states/uni_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_ls_state.h @@ -45,7 +45,7 @@ class UniLS : public BaseState { return uni_ls_to_unconstrained(temp); } - void set_from_unconstrained(Eigen::VectorXd in) override { + void set_from_unconstrained(const Eigen::VectorXd &in) override { Eigen::VectorXd temp = uni_ls_to_constrained(in); mean = temp(0); var = temp(1); From 97d45fecc30ee6d67c84657e666ab8c0af00967d Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Thu, 21 Apr 2022 13:04:26 +0100 Subject: [PATCH 260/317] fix --- src/hierarchies/priors/nw_prior_model.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/priors/nw_prior_model.cc b/src/hierarchies/priors/nw_prior_model.cc index cfee03f4c..46c9f6c77 100644 --- a/src/hierarchies/priors/nw_prior_model.cc +++ b/src/hierarchies/priors/nw_prior_model.cc @@ -62,7 +62,7 @@ void NWPriorModel::initialize_hypers() { // for mu0 Eigen::VectorXd mu00 = bayesmix::to_eigen(prior->ngiw_prior().mean_prior().mean()); - unsigned int dim = mu00.size(); + dim = mu00.size(); Eigen::MatrixXd sigma00 = bayesmix::to_eigen(prior->ngiw_prior().mean_prior().var()); // for lambda0 From 7795f5eeee048577324d5b24614edb673b15cab0 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Fri, 22 Apr 2022 10:06:41 +0100 Subject: [PATCH 261/317] make gammagamma great again --- examples/gamma_hierarchy/gamma_likelihood.h | 42 +++++++++---------- examples/gamma_hierarchy/gamma_prior_model.h | 36 +++++++--------- examples/gamma_hierarchy/gammagamma_updater.h | 4 +- src/proto/algorithm_state.proto | 3 +- 4 files changed, 37 insertions(+), 48 deletions(-) diff --git a/examples/gamma_hierarchy/gamma_likelihood.h b/examples/gamma_hierarchy/gamma_likelihood.h index 3f8aac15a..25b3103c5 100644 --- a/examples/gamma_hierarchy/gamma_likelihood.h +++ b/examples/gamma_hierarchy/gamma_likelihood.h @@ -9,11 +9,29 @@ #include "algorithm_state.pb.h" #include "src/hierarchies/likelihoods/base_likelihood.h" +#include "src/hierarchies/likelihoods/states/base_state.h" namespace State { -class Gamma { +class Gamma : public BaseState { public: double shape, rate; + using ProtoState = bayesmix::AlgorithmState::ClusterState; + + ProtoState get_as_proto() const override { + ProtoState out; + out.mutable_general_state()->set_size(2); + out.mutable_general_state()->mutable_data()->Add(shape); + out.mutable_general_state()->mutable_data()->Add(rate); + return out; + } + + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } + shape = state_.general_state().data()[0]; + rate = state_.general_state().data()[1]; + } }; } // namespace State @@ -23,16 +41,12 @@ class GammaLikelihood : public BaseLikelihood { ~GammaLikelihood() = default; bool is_multivariate() const override { return false; }; bool is_dependent() const override { return false; }; - void set_state_from_proto(const google::protobuf::Message &state_, - bool update_card = true) override; void clear_summary_statistics() override; // Getters and Setters int get_ndata() const { return ndata; }; double get_shape() const { return state.shape; }; double get_data_sum() const { return data_sum; }; - std::shared_ptr get_state_proto() - const override; protected: double compute_lpdf(const Eigen::RowVectorXd &datum) const override; @@ -45,29 +59,11 @@ class GammaLikelihood : public BaseLikelihood { }; /* DEFINITIONS */ -void GammaLikelihood::set_state_from_proto( - const google::protobuf::Message &state_, bool update_card) { - auto &statecast = downcast_state(state_); - state.rate = statecast.general_state().data()[0]; - if (update_card) set_card(statecast.cardinality()); -} - void GammaLikelihood::clear_summary_statistics() { data_sum = 0; ndata = 0; } -std::shared_ptr -GammaLikelihood::get_state_proto() const { - bayesmix::Vector state_; - state_.mutable_data()->Add(state.shape); - state_.mutable_data()->Add(state.rate); - - auto out = std::make_shared(); - out->mutable_general_state()->CopyFrom(state_); - return out; -} - double GammaLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { return stan::math::gamma_lpdf(datum(0), state.shape, state.rate); } diff --git a/examples/gamma_hierarchy/gamma_prior_model.h b/examples/gamma_hierarchy/gamma_prior_model.h index 3c226df33..a49bafadc 100644 --- a/examples/gamma_hierarchy/gamma_prior_model.h +++ b/examples/gamma_hierarchy/gamma_prior_model.h @@ -7,6 +7,7 @@ #include #include "algorithm_state.pb.h" +#include "gamma_likelihood.h" #include "hierarchy_prior.pb.h" #include "src/hierarchies/priors/base_prior_model.h" #include "src/utils/rng.h" @@ -18,7 +19,7 @@ struct Gamma { } // namespace Hyperparams class GammaPriorModel - : public BasePriorModel { public: using AbstractPriorModel::ProtoHypers; @@ -30,8 +31,7 @@ class GammaPriorModel double lpdf(const google::protobuf::Message &state_) override; - std::shared_ptr sample( - ProtoHypersPtr hier_hypers = nullptr) override; + State::Gamma sample(ProtoHypersPtr hier_hypers = nullptr) override; void update_hypers(const std::vector &states) override { @@ -61,36 +61,30 @@ double GammaPriorModel::lpdf(const google::protobuf::Message &state_) { return stan::math::gamma_lpdf(rate, hypers->rate_alpha, hypers->rate_beta); } -std::shared_ptr GammaPriorModel::sample( - ProtoHypersPtr hier_hypers) { +State::Gamma GammaPriorModel::sample(ProtoHypersPtr hier_hypers) { auto &rng = bayesmix::Rng::Instance().get(); + State::Gamma out; - auto params = (hier_hypers) ? hier_hypers->fake_prior() - : get_hypers_proto()->fake_prior(); + auto params = (hier_hypers) ? hier_hypers->general_state() + : get_hypers_proto()->general_state(); double rate_alpha = params.data()[0]; double rate_beta = params.data()[1]; - double new_rate = stan::math::gamma_rng(rate_alpha, rate_beta, rng); - - bayesmix::AlgorithmState::ClusterState out; - out.mutable_general_state()->mutable_data()->Add(shape); - out.mutable_general_state()->mutable_data()->Add(new_rate); - return std::make_shared(out); + out.shape = shape; + out.rate = stan::math::gamma_rng(rate_alpha, rate_beta, rng); + return out; } void GammaPriorModel::set_hypers_from_proto( const google::protobuf::Message &hypers_) { - auto &hyperscast = downcast_hypers(hypers_); - hypers->rate_alpha = hyperscast.fake_prior().data()[0]; - hypers->rate_beta = hyperscast.fake_prior().data()[1]; + auto &hyperscast = downcast_hypers(hypers_).general_state(); + hypers->rate_alpha = hyperscast.data()[0]; + hypers->rate_beta = hyperscast.data()[1]; }; GammaPriorModel::ProtoHypersPtr GammaPriorModel::get_hypers_proto() const { - bayesmix::Vector hypers_; - hypers_.mutable_data()->Add(hypers->rate_alpha); - hypers_.mutable_data()->Add(hypers->rate_beta); - ProtoHypersPtr out = std::make_shared(); - out->mutable_fake_prior()->CopyFrom(hypers_); + out->mutable_general_state()->mutable_data()->Add(hypers->rate_alpha); + out->mutable_general_state()->mutable_data()->Add(hypers->rate_beta); return out; }; diff --git a/examples/gamma_hierarchy/gammagamma_updater.h b/examples/gamma_hierarchy/gammagamma_updater.h index 310ac65da..dd05b6ffd 100644 --- a/examples/gamma_hierarchy/gammagamma_updater.h +++ b/examples/gamma_hierarchy/gammagamma_updater.h @@ -41,8 +41,8 @@ AbstractUpdater::ProtoHypersPtr GammaGammaUpdater::compute_posterior_hypers( // Proto conversion ProtoHypers out; - out.mutable_fake_prior()->mutable_data()->Add(rate_alpha_new); - out.mutable_fake_prior()->mutable_data()->Add(rate_beta_new); + out.mutable_general_state()->mutable_data()->Add(rate_alpha_new); + out.mutable_general_state()->mutable_data()->Add(rate_beta_new); return std::make_shared(out); } diff --git a/src/proto/algorithm_state.proto b/src/proto/algorithm_state.proto index 2db4cbc47..d106d8cf9 100644 --- a/src/proto/algorithm_state.proto +++ b/src/proto/algorithm_state.proto @@ -38,8 +38,7 @@ message AlgorithmState { message HierarchyHypers { // Current values of the Hyperparameters of the Hierarchy oneof val { - Vector fake_prior = 1; - //EmptyPrior fake_prior = 1; + Vector general_state = 1; NIGDistribution nnig_state = 2; NWDistribution nnw_state = 3; MultiNormalIGDistribution lin_reg_uni_state = 4; From 08d8680b6c694b0fc983ede0504f1acccb1d7771 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Fri, 29 Apr 2022 10:54:49 +0300 Subject: [PATCH 262/317] added docs for states --- .../likelihoods/states/base_state.h | 36 +++++++++++++++++-- src/hierarchies/likelihoods/states/fa_state.h | 11 ++++++ .../likelihoods/states/multi_ls_state.h | 4 +++ .../likelihoods/states/uni_lin_reg_ls_state.h | 3 ++ .../likelihoods/states/uni_ls_state.h | 2 ++ 5 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/hierarchies/likelihoods/states/base_state.h b/src/hierarchies/likelihoods/states/base_state.h index 83d9fb7f0..3a9163271 100644 --- a/src/hierarchies/likelihoods/states/base_state.h +++ b/src/hierarchies/likelihoods/states/base_state.h @@ -9,25 +9,57 @@ namespace State { +//! Abstract base class for a generic state +//! +//! Given a statistical model with likelihood L(y|tau) and prior p(tau) +//! a State class represents the value of tau at a certain MCMC iteration. +//! In addition, each instance stores the cardinality of the number of +//! observations in the model. +//! +//! State classes inheriting from this one should implement the methods +//! `set_from_proto` and `to_proto`, that are used to deserialzie from +//! (and serialize to) a `bayesmix::AlgorithmState::ClusterState` +//! protocol buffer message. +//! +//! Optionally, each state can have an "unconstrained" representation, +//! where a bijective transformation B is applied to `tau`, so that +//! the image of B(tau) is R^d for some d. +//! This is essential for the default updaters such as `RandomWalkUpdater` +//! and `MalaUpdater` to work, but is not necessary for other model-specific +//! updaters. +//! If such a representation is needed, child classes should also implement +//! `get_unconstrained`, `set_from_unconstrained`, and `log_det_jac`. + class BaseState { public: int card; using ProtoState = bayesmix::AlgorithmState::ClusterState; + //! Returns the unconstrained representation x = B(tau) virtual Eigen::VectorXd get_unconstrained() { return Eigen::VectorXd(0); } + //! Sets the current state as tau = B^{-1}(in) + //! @param in the unconstrained representation of the state virtual void set_from_unconstrained(const Eigen::VectorXd &in) {} + //! Returns the log determinant of the jacobian of B^{-1} + virtual double log_det_jac() { return -1; } + + //! Sets the current state from a protobuf object + //! @param state_ a bayesmix::AlgorithmState::ClusterState instance + //! @param update_card if true, the current cardinality is udpdate virtual void set_from_proto(const ProtoState &state_, bool update_card) = 0; + //! Returns a `bayesmix::AlgorithmState::ClusterState` representig the + //! current value of the state virtual ProtoState get_as_proto() const = 0; + //! Returns a shared pointer to `bayesmix::AlgorithmState::ClusterState` + //! representig the current value of the state std::shared_ptr to_proto() const { return std::make_shared(get_as_proto()); } - - virtual double log_det_jac() { return -1; } }; } // namespace State diff --git a/src/hierarchies/likelihoods/states/fa_state.h b/src/hierarchies/likelihoods/states/fa_state.h index 8f299859c..2417a194f 100644 --- a/src/hierarchies/likelihoods/states/fa_state.h +++ b/src/hierarchies/likelihoods/states/fa_state.h @@ -12,6 +12,17 @@ namespace State { +//! State of a Factor Analytic model +//! Y_i = lambda * eta_i + err +//! where `Y_i` is a `p`-dimensional vetor, `eta_i` is a d-dimensional one, +//! `lambda` is a `p x d` matrix and `err` is an error term with mean zero and +//! diagonal covariance matrix `psi`. +//! +//! For faster likelihood evaluation, we store also the `cov_wood` factor and +//! the log determinant of the matrix `lambda * lambda^T + psi`, see +//! the `compute_wood_chol_and_logdet` function for more details. +//! +//! The unconstrained representation for this state is not implemented. class FA : public BaseState { public: Eigen::VectorXd mu, psi; diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h index dc7306e7a..8c9309845 100644 --- a/src/hierarchies/likelihoods/states/multi_ls_state.h +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -45,6 +45,10 @@ T multi_ls_log_det_jac( return out; } +//! A univariate location-scale state with parametrization (mean, Cov) +//! where Cov is the covariance matrix. +//! The unconstrained representation corresponds to (mean, B(cov)), for +//! B the `stan::math::cov_matrix_free` transformation. class MultiLS : public BaseState { public: Eigen::VectorXd mean; diff --git a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h index c3a0115a1..f017127ca 100644 --- a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h @@ -37,6 +37,9 @@ T uni_lin_reg_log_det_jac(Eigen::Matrix constrained) { return out; } +//! State of a scalar linear regression model with parameters +//! (regression_coeffs, var), where var is the variance of the error term. +//! The unconstrained representation is (regression_coeffs, log(var)). class UniLinRegLS : public BaseState { public: Eigen::VectorXd regression_coeffs; diff --git a/src/hierarchies/likelihoods/states/uni_ls_state.h b/src/hierarchies/likelihoods/states/uni_ls_state.h index 335fc6cf7..9aba6e803 100644 --- a/src/hierarchies/likelihoods/states/uni_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_ls_state.h @@ -33,6 +33,8 @@ T uni_ls_log_det_jac(Eigen::Matrix constrained) { return out; } +//! A univariate location-scale state with parametrization (mean, var) +//! The unconstrained representation corresponds to (mean, log(var)) class UniLS : public BaseState { public: double mean, var; From 5440e352e9779a4a9d834a4ce9504f6ec9b2157e Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Fri, 29 Apr 2022 13:29:49 +0300 Subject: [PATCH 263/317] parenthesis --- src/hierarchies/likelihoods/states/base_state.h | 4 ++-- src/hierarchies/likelihoods/states/fa_state.h | 2 +- src/hierarchies/likelihoods/states/multi_ls_state.h | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hierarchies/likelihoods/states/base_state.h b/src/hierarchies/likelihoods/states/base_state.h index 3a9163271..0e04ff182 100644 --- a/src/hierarchies/likelihoods/states/base_state.h +++ b/src/hierarchies/likelihoods/states/base_state.h @@ -17,7 +17,7 @@ namespace State { //! observations in the model. //! //! State classes inheriting from this one should implement the methods -//! `set_from_proto` and `to_proto`, that are used to deserialzie from +//! `set_from_proto()` and `to_proto()`, that are used to deserialzie from //! (and serialize to) a `bayesmix::AlgorithmState::ClusterState` //! protocol buffer message. //! @@ -28,7 +28,7 @@ namespace State { //! and `MalaUpdater` to work, but is not necessary for other model-specific //! updaters. //! If such a representation is needed, child classes should also implement -//! `get_unconstrained`, `set_from_unconstrained`, and `log_det_jac`. +//! `get_unconstrained()`, `set_from_unconstrained()`, and `log_det_jac()`. class BaseState { public: diff --git a/src/hierarchies/likelihoods/states/fa_state.h b/src/hierarchies/likelihoods/states/fa_state.h index 2417a194f..bfb96d842 100644 --- a/src/hierarchies/likelihoods/states/fa_state.h +++ b/src/hierarchies/likelihoods/states/fa_state.h @@ -20,7 +20,7 @@ namespace State { //! //! For faster likelihood evaluation, we store also the `cov_wood` factor and //! the log determinant of the matrix `lambda * lambda^T + psi`, see -//! the `compute_wood_chol_and_logdet` function for more details. +//! the `compute_wood_chol_and_logdet(...)` function for more details. //! //! The unconstrained representation for this state is not implemented. class FA : public BaseState { diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h index 8c9309845..346905892 100644 --- a/src/hierarchies/likelihoods/states/multi_ls_state.h +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -48,7 +48,7 @@ T multi_ls_log_det_jac( //! A univariate location-scale state with parametrization (mean, Cov) //! where Cov is the covariance matrix. //! The unconstrained representation corresponds to (mean, B(cov)), for -//! B the `stan::math::cov_matrix_free` transformation. +//! B the `stan::math::cov_matrix_free()` transformation. class MultiLS : public BaseState { public: Eigen::VectorXd mean; From e445ef9b0e3a2e9f36cb96bd78a2572de82eb947 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Sun, 1 May 2022 15:25:52 +0300 Subject: [PATCH 264/317] addressing comments --- src/hierarchies/likelihoods/states/base_state.h | 10 +++++----- src/hierarchies/likelihoods/states/fa_state.h | 2 ++ src/hierarchies/likelihoods/states/multi_ls_state.h | 13 +++++++++++-- src/hierarchies/likelihoods/states/uni_ls_state.h | 7 +++++++ 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/hierarchies/likelihoods/states/base_state.h b/src/hierarchies/likelihoods/states/base_state.h index 0e04ff182..891302574 100644 --- a/src/hierarchies/likelihoods/states/base_state.h +++ b/src/hierarchies/likelihoods/states/base_state.h @@ -17,13 +17,13 @@ namespace State { //! observations in the model. //! //! State classes inheriting from this one should implement the methods -//! `set_from_proto()` and `to_proto()`, that are used to deserialzie from +//! `set_from_proto()` and `to_proto()`, that are used to deserialize from //! (and serialize to) a `bayesmix::AlgorithmState::ClusterState` //! protocol buffer message. //! //! Optionally, each state can have an "unconstrained" representation, //! where a bijective transformation B is applied to `tau`, so that -//! the image of B(tau) is R^d for some d. +//! the image of B is R^d for some d. //! This is essential for the default updaters such as `RandomWalkUpdater` //! and `MalaUpdater` to work, but is not necessary for other model-specific //! updaters. @@ -48,15 +48,15 @@ class BaseState { //! Sets the current state from a protobuf object //! @param state_ a bayesmix::AlgorithmState::ClusterState instance - //! @param update_card if true, the current cardinality is udpdate + //! @param update_card if true, the current cardinality is updated virtual void set_from_proto(const ProtoState &state_, bool update_card) = 0; - //! Returns a `bayesmix::AlgorithmState::ClusterState` representig the + //! Returns a `bayesmix::AlgorithmState::ClusterState` representing the //! current value of the state virtual ProtoState get_as_proto() const = 0; //! Returns a shared pointer to `bayesmix::AlgorithmState::ClusterState` - //! representig the current value of the state + //! representing the current value of the state std::shared_ptr to_proto() const { return std::make_shared(get_as_proto()); } diff --git a/src/hierarchies/likelihoods/states/fa_state.h b/src/hierarchies/likelihoods/states/fa_state.h index bfb96d842..7f7a9a62c 100644 --- a/src/hierarchies/likelihoods/states/fa_state.h +++ b/src/hierarchies/likelihoods/states/fa_state.h @@ -45,6 +45,8 @@ class FA : public BaseState { compute_wood_factors(); } + //! Sets cov_logdet and cov_wood by calling + //! bayesmix::compute_wood_chol_and_logdet() void compute_wood_factors() { auto [cov_wood_, cov_logdet_] = bayesmix::compute_wood_chol_and_logdet(psi_inverse, lambda); diff --git a/src/hierarchies/likelihoods/states/multi_ls_state.h b/src/hierarchies/likelihoods/states/multi_ls_state.h index 346905892..a09221363 100644 --- a/src/hierarchies/likelihoods/states/multi_ls_state.h +++ b/src/hierarchies/likelihoods/states/multi_ls_state.h @@ -10,6 +10,10 @@ namespace State { +//! Returns the unonstrained parametrization from the +//! unconstrained one, i.e. [mean_in[0], B(prec_in[1])] +//! where B is the `stan::math::cov_matrix_free()` +//! transformation. template Eigen::Matrix multi_ls_to_unconstrained( Eigen::Matrix mean_in, @@ -21,6 +25,8 @@ Eigen::Matrix multi_ls_to_unconstrained( return out; } +//! Returns the unonstrained parametrization from the +//! unconstrained one template std::tuple, Eigen::Matrix> @@ -36,6 +42,9 @@ multi_ls_to_constrained(Eigen::Matrix in) { return std::make_tuple(mean, prec); } +//! Returns the log determinant of the jacobian of the map +//! (x, y) -> (x, B(y)), that is the inverse map of the +//! constrained -> unconstrained representation. template T multi_ls_log_det_jac( Eigen::Matrix prec_constrained) { @@ -47,8 +56,8 @@ T multi_ls_log_det_jac( //! A univariate location-scale state with parametrization (mean, Cov) //! where Cov is the covariance matrix. -//! The unconstrained representation corresponds to (mean, B(cov)), for -//! B the `stan::math::cov_matrix_free()` transformation. +//! The unconstrained representation corresponds to (mean, B(cov)), where +//! B is the `stan::math::cov_matrix_free()` transformation. class MultiLS : public BaseState { public: Eigen::VectorXd mean; diff --git a/src/hierarchies/likelihoods/states/uni_ls_state.h b/src/hierarchies/likelihoods/states/uni_ls_state.h index 9aba6e803..c553ef60d 100644 --- a/src/hierarchies/likelihoods/states/uni_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_ls_state.h @@ -10,6 +10,8 @@ namespace State { +//! Returns the constrained parametrization from the +//! unconstrained one, i.e. [in[0], exp(in[1])] template Eigen::Matrix uni_ls_to_constrained( Eigen::Matrix in) { @@ -18,6 +20,8 @@ Eigen::Matrix uni_ls_to_constrained( return out; } +//! Returns the unconstrained parametrization from the +//! constrained one, i.e. [in[0], log(in[1])] template Eigen::Matrix uni_ls_to_unconstrained( Eigen::Matrix in) { @@ -26,6 +30,9 @@ Eigen::Matrix uni_ls_to_unconstrained( return out; } +//! Returns the log determinant of the jacobian of the map +//! (x, y) -> (x, log(y)), that is the inverse map of the +//! constrained -> unconstrained representation. template T uni_ls_log_det_jac(Eigen::Matrix constrained) { T out = 0; From 12831ef3a803adb4ce7c158a32dac100654aff3c Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Sun, 1 May 2022 15:28:30 +0300 Subject: [PATCH 265/317] docs also for linreg --- .../likelihoods/states/uni_lin_reg_ls_state.h | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h index f017127ca..0b650c878 100644 --- a/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h +++ b/src/hierarchies/likelihoods/states/uni_lin_reg_ls_state.h @@ -11,6 +11,10 @@ namespace State { +//! Returns the constrained parametrization from the +//! unconstrained one, i.e. [a, exp(b)], +//! where `a` is equal to the vector `in` excluding its last +//! element, and `b` is the last element in `in` template Eigen::Matrix uni_lin_reg_to_constrained( Eigen::Matrix in) { @@ -20,6 +24,10 @@ Eigen::Matrix uni_lin_reg_to_constrained( return out; } +//! Returns the unconstrained parametrization from the +//! constrained one, i.e. [a, log(b)] +//! where `a` is equal to the vector `in` excluding its last +//! element, and `b` is the last element in `in` template Eigen::Matrix uni_lin_reg_to_unconstrained( Eigen::Matrix in) { @@ -29,6 +37,9 @@ Eigen::Matrix uni_lin_reg_to_unconstrained( return out; } +//! Returns the log determinant of the jacobian of the map +//! (x, y) -> (x, log(y)), that is the inverse map of the +//! constrained -> unconstrained representation. template T uni_lin_reg_log_det_jac(Eigen::Matrix constrained) { T out = 0; From dfc10363db23765951bacdf55aa09b2edab9f399 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Fri, 29 Apr 2022 13:21:53 +0300 Subject: [PATCH 266/317] basic likelihood docs --- src/hierarchies/likelihoods/abstract_likelihood.h | 14 ++++++++++++++ src/hierarchies/likelihoods/base_likelihood.h | 9 +++++++++ src/hierarchies/likelihoods/fa_likelihood.h | 5 +++++ src/hierarchies/likelihoods/laplace_likelihood.h | 2 ++ .../likelihoods/multi_norm_likelihood.h | 2 ++ src/hierarchies/likelihoods/uni_norm_likelihood.h | 2 ++ 6 files changed, 34 insertions(+) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index bebb3804b..47f1fcc69 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -9,6 +9,20 @@ #include "algorithm_state.pb.h" +//! Abstract class for a generic likelihood +//! +//! This class is the basis for a curiously recurring template pattern (CRTP) +//! for `Mixing` objects, and is solely composed of interface functions for +//! derived classes to use. +//! +//! A likelihood can evaluate the log probability density faction (lpdf) at a +//! certain point given the current value of the parameters, or compute +//! directly the lpdf for the whole cluster. +//! +//! Whenever possible, we store in a `Likelihood` instance also the sufficient +//! statistics of the data allocated to the cluster, in order to speed-up +//! computations. + class AbstractLikelihood { public: //! Default destructor diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index d7da79553..bb96f49b0 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -13,6 +13,15 @@ #include "likelihood_internal.h" #include "src/utils/covariates_getter.h" +//! Base template class of a likelihood object + +//! This class derives from `AbstractLikelihood` and is templated over +//! `Derived` (needed for the curiously recurring template pattern) and +//! `State`: an instance of `BaseState` + +//! @tparam Derived Name of the implemented derived class +//! @tparam State Class name of the container for state values + template class BaseLikelihood : public AbstractLikelihood { public: diff --git a/src/hierarchies/likelihoods/fa_likelihood.h b/src/hierarchies/likelihoods/fa_likelihood.h index 926c022f9..7e68e23ff 100644 --- a/src/hierarchies/likelihoods/fa_likelihood.h +++ b/src/hierarchies/likelihoods/fa_likelihood.h @@ -12,6 +12,11 @@ #include "base_likelihood.h" #include "states/includes.h" +//! A gaussian factor analytic likelihood, that is +//! Y ~ N_p(mu, Lambda * Lambda^T + Psi) +//! Where Lambda is a `p x d` matrix, usually d << p and `Psi` is a diagonal +//! matrix. + class FALikelihood : public BaseLikelihood { public: FALikelihood() = default; diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h index 26b9243a1..59473ffb0 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.h +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -11,6 +11,8 @@ #include "base_likelihood.h" #include "states/includes.h" +//! A univariate laplace likelihood + class LaplaceLikelihood : public BaseLikelihood { public: diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index 1887c93f4..e90a2a01f 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -12,6 +12,8 @@ #include "base_likelihood.h" #include "states/includes.h" +//! A multivariate normal likelihood + class MultiNormLikelihood : public BaseLikelihood { public: diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 7aba9a959..37d4503e3 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -11,6 +11,8 @@ #include "base_likelihood.h" #include "states/includes.h" +//! A univariate normal likelihood, using the `State::UniLS` state. + class UniNormLikelihood : public BaseLikelihood { public: From b9a189f5930e0333f67d96cb099a0c9455b86a6c Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Sun, 1 May 2022 15:19:02 +0300 Subject: [PATCH 267/317] more docs --- src/hierarchies/likelihoods/multi_norm_likelihood.h | 7 +++++++ src/hierarchies/likelihoods/uni_lin_reg_likelihood.h | 11 +++++++++++ src/hierarchies/likelihoods/uni_norm_likelihood.h | 7 +++++++ 3 files changed, 25 insertions(+) diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index e90a2a01f..5c96104f5 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -13,6 +13,13 @@ #include "states/includes.h" //! A multivariate normal likelihood +//! +//! Represents the model: +//! y_1, ..., y_m ~ N(mu, Cov) +//! where (mu, Cov) are stored in a `State::MultiLS` state +//! +//! The sufficient statistics store are the sum of the y_i's +//! and the sum of y_i^T } y_i. class MultiNormLikelihood : public BaseLikelihood { diff --git a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h index 4f8d7eaea..96600d4dc 100644 --- a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h +++ b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h @@ -11,6 +11,17 @@ #include "base_likelihood.h" #include "states/includes.h" +//! A scalar linear regression model +//! +//! Represents the model: +//! y_i ~ N(x_i^T * reg_coeffs, var) +//! where (reg_coeffs, var) are stored in a `State::UniLinRegLS` state +//! +//! The sufficient statistics stored are the +//! 1) sum of y_i^2 +//! 2) sum of x_i^T x_i +//! 3) sum of y_i x_i^T + class UniLinRegLikelihood : public BaseLikelihood { public: diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 37d4503e3..937ae8ec9 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -12,6 +12,13 @@ #include "states/includes.h" //! A univariate normal likelihood, using the `State::UniLS` state. +//! +//! Represents the model: +//! y_1, ..., y_m ~ N(mu, var) +//! where (mu, var) are stored in a `State::UniLS` state +//! +//! The sufficient statistics store are the sum of the y_i's +//! and the sum of y_i^2. class UniNormLikelihood : public BaseLikelihood { From 7379ad47cc30f3f452043f7e0b000d18139419bd Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Sun, 1 May 2022 15:55:35 +0300 Subject: [PATCH 268/317] more docs --- src/hierarchies/likelihoods/fa_likelihood.h | 6 ++++++ src/hierarchies/likelihoods/laplace_likelihood.h | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/hierarchies/likelihoods/fa_likelihood.h b/src/hierarchies/likelihoods/fa_likelihood.h index 7e68e23ff..0c5e4938d 100644 --- a/src/hierarchies/likelihoods/fa_likelihood.h +++ b/src/hierarchies/likelihoods/fa_likelihood.h @@ -16,6 +16,12 @@ //! Y ~ N_p(mu, Lambda * Lambda^T + Psi) //! Where Lambda is a `p x d` matrix, usually d << p and `Psi` is a diagonal //! matrix. +//! +//! Parameters are stored in a `State::FA` state. +//! We store as summary statistics the sum of the y_i's, but it is +//! not sufficient for all the updates involved. Therefore, all the +//! observations allocated to a cluster are processed when computing the +//! cluster lpdf. class FALikelihood : public BaseLikelihood { public: diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h index 59473ffb0..6f222ca93 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.h +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -12,6 +12,16 @@ #include "states/includes.h" //! A univariate laplace likelihood +//! +//! Represents the model: +//! y_i ~ Laplace(mu, var) +//! where mu is the mean and center of the distribution +//! and var is the variance. The scale is then sqrt(var / 2) +//! These parameters are stored in a `State::UniLS` state +//! +//! Since the Laplce likelihood does not have sufficient statistics +//! other than the whole sample, the `update_sum_stats` method +//! does nothing. class LaplaceLikelihood : public BaseLikelihood { From 70de8aff636bd8ce5f81f054d401acdc9cf2a211 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 2 May 2022 16:18:44 +0100 Subject: [PATCH 269/317] typo --- src/hierarchies/likelihoods/multi_norm_likelihood.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index 5c96104f5..eb7eabaed 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -19,7 +19,7 @@ //! where (mu, Cov) are stored in a `State::MultiLS` state //! //! The sufficient statistics store are the sum of the y_i's -//! and the sum of y_i^T } y_i. +//! and the sum of y_i^T y_i. class MultiNormLikelihood : public BaseLikelihood { From bf158ce9f6fb67b521df757364734761094ec389 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Sat, 7 May 2022 08:57:31 +0200 Subject: [PATCH 270/317] typos --- src/hierarchies/likelihoods/laplace_likelihood.h | 2 +- src/hierarchies/likelihoods/multi_norm_likelihood.h | 2 +- src/hierarchies/likelihoods/uni_norm_likelihood.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h index 6f222ca93..6b9196b8a 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.h +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -19,7 +19,7 @@ //! and var is the variance. The scale is then sqrt(var / 2) //! These parameters are stored in a `State::UniLS` state //! -//! Since the Laplce likelihood does not have sufficient statistics +//! Since the Laplace likelihood does not have sufficient statistics //! other than the whole sample, the `update_sum_stats` method //! does nothing. diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index eb7eabaed..13f338faa 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -18,7 +18,7 @@ //! y_1, ..., y_m ~ N(mu, Cov) //! where (mu, Cov) are stored in a `State::MultiLS` state //! -//! The sufficient statistics store are the sum of the y_i's +//! The sufficient statistics stored are the sum of the y_i's //! and the sum of y_i^T y_i. class MultiNormLikelihood diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index 937ae8ec9..eb17f6853 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -17,7 +17,7 @@ //! y_1, ..., y_m ~ N(mu, var) //! where (mu, var) are stored in a `State::UniLS` state //! -//! The sufficient statistics store are the sum of the y_i's +//! The sufficient statistics stored are the sum of the y_i's //! and the sum of y_i^2. class UniNormLikelihood From f9531441545a4473ef0aec37a78a07e090497cfb Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Sat, 7 May 2022 09:16:40 +0200 Subject: [PATCH 271/317] updater docs --- src/hierarchies/updaters/abstract_updater.h | 6 ++++ src/hierarchies/updaters/fa_updater.h | 5 ++++ src/hierarchies/updaters/mala_updater.h | 27 +++++++++++++++++ src/hierarchies/updaters/metropolis_updater.h | 29 +++++++++++++++++++ src/hierarchies/updaters/mnig_updater.h | 9 ++++++ src/hierarchies/updaters/nnig_updater.h | 9 ++++++ src/hierarchies/updaters/nnw_updater.h | 9 ++++++ src/hierarchies/updaters/nnxig_updater.h | 9 ++++++ .../updaters/random_walk_updater.h | 25 ++++++++++++++++ .../updaters/semi_conjugate_updater.h | 21 ++++++++++++++ 10 files changed, 149 insertions(+) diff --git a/src/hierarchies/updaters/abstract_updater.h b/src/hierarchies/updaters/abstract_updater.h index b1c8b849c..eaa19ca58 100644 --- a/src/hierarchies/updaters/abstract_updater.h +++ b/src/hierarchies/updaters/abstract_updater.h @@ -4,6 +4,12 @@ #include "src/hierarchies/likelihoods/abstract_likelihood.h" #include "src/hierarchies/priors/abstract_prior_model.h" +//! Base class for an Updater + +//! An updater is a class able to sample from the full conditional +//! distribution of an `Hierarchy`, coming from the product of a +//! `Likelihood` and a `Prior`, possibly using a Metropolis-Hastings +//! algorithm. class AbstractUpdater { public: // Type aliases diff --git a/src/hierarchies/updaters/fa_updater.h b/src/hierarchies/updaters/fa_updater.h index 671cbaae5..ec70ec997 100644 --- a/src/hierarchies/updaters/fa_updater.h +++ b/src/hierarchies/updaters/fa_updater.h @@ -8,6 +8,11 @@ #include "src/hierarchies/priors/hyperparams.h" #include "src/utils/proto_utils.h" +//! Updater specific for the `FAHierachy`. +//! See Bhattacharya, Anirban, and David B. Dunson. +//! "Sparse Bayesian infinite factor models." Biometrika (2011): 291-306. +//! for further details + class FAUpdater : public AbstractUpdater { public: FAUpdater() = default; diff --git a/src/hierarchies/updaters/mala_updater.h b/src/hierarchies/updaters/mala_updater.h index a619780f2..57a97246c 100644 --- a/src/hierarchies/updaters/mala_updater.h +++ b/src/hierarchies/updaters/mala_updater.h @@ -5,6 +5,17 @@ #include "metropolis_updater.h" +//! Metropolis Adjusted Langevin Algorithm. +//! +//! This class requires that the Hierarchy's state implements +//! the `get_unconstrained()`, `set_from_unconstrained()` and +//! `log_det_jac` functions. +//! +//! Given the current value of the unconstrained parameters x, a new +//! value is proposed from +//! x_new ~ N(x + step_size * grad(full_cond)(x), sqrt(2 step_size) * I) +//! and then either accepted (in which case the hierarchy's state is +//! set to x_new) or rejected. class MalaUpdater : public MetropolisUpdater { protected: double step_size; @@ -15,6 +26,13 @@ class MalaUpdater : public MetropolisUpdater { MalaUpdater(double step_size) : step_size(step_size) {} + //! Samples from the proposal distribution + //! @param curr_state the current state (unconstrained parametrization) + //! @param like instance of likelihood + //! @param prior instance of prior + //! @param target_lpdf either double or stan::math::var. Needed for + //! stan's automatic differentiation. It will be + //! filled with the lpdf af the 'curr_state' Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, AbstractLikelihood &like, AbstractPriorModel &prior, @@ -31,6 +49,14 @@ class MalaUpdater : public MetropolisUpdater { return curr_state + step_size * grad + noise; } + //! Evaluates the log probability density function of the proposal + //! @param prop_state the proposed state (at which to evaluate the lpdf) + //! @param curr_state the current state (unconstrained parametrization) + //! @param like instance of likelihood + //! @param prior instance of prior + //! @param target_lpdf either double or stan::math::var. Needed for + //! stan's automatic differentiation. It will be + //! filled with the lpdf af 'curr_state' double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, AbstractLikelihood &like, AbstractPriorModel &prior, target_lpdf_unconstrained &target_lpdf) { @@ -48,6 +74,7 @@ class MalaUpdater : public MetropolisUpdater { return out; } + //! Returns a shared_ptr to a new instance of `this` std::shared_ptr clone() const { auto out = std::make_shared(static_cast(*this)); diff --git a/src/hierarchies/updaters/metropolis_updater.h b/src/hierarchies/updaters/metropolis_updater.h index 9cc662395..ae30aebf8 100644 --- a/src/hierarchies/updaters/metropolis_updater.h +++ b/src/hierarchies/updaters/metropolis_updater.h @@ -4,11 +4,40 @@ #include "abstract_updater.h" #include "target_lpdf_unconstrained.h" +//! Base class for updaters using a Metropolis-Hastings algorithm +//! +//! This class serves as the base for a CRTP. +//! Children of this class should implement the methods +//! template +//! Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, +//! AbstractLikelihood &like, +//! AbstractPriorModel &prior, F +//! &target_lpdf) +//! and +//! template +//! double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd +//! curr_state, +//! AbstractLikelihood &like, AbstractPriorModel +//! &prior, +//! F &target_lpdf) +//! where the template parameter is neeeded to allow the use of stan's +//! automatic differentiation if the gradient of the full conditional is +//! required. template class MetropolisUpdater : public AbstractUpdater { public: + //! Samples from the full conditional distribution using a + //! Metropolis-Hasings step void draw(AbstractLikelihood &like, AbstractPriorModel &prior, bool update_params) override { + if (update_params) { + throw std::runtime_error( + "'update_params' can be True only when using instances of" + " 'SemiConjugateUpdater'. This is likely caused by" + " using a nonconjugate hierarchy (or a nonconjugate updater)" + " in a marginal algorithm such as 'Neal3'."); + } + target_lpdf_unconstrained target_lpdf(&like, &prior); Eigen::VectorXd curr_state = like.get_unconstrained_state(); Eigen::VectorXd prop_state = diff --git a/src/hierarchies/updaters/mnig_updater.h b/src/hierarchies/updaters/mnig_updater.h index 4f7bb2d35..490558d1a 100644 --- a/src/hierarchies/updaters/mnig_updater.h +++ b/src/hierarchies/updaters/mnig_updater.h @@ -5,6 +5,15 @@ #include "src/hierarchies/likelihoods/uni_lin_reg_likelihood.h" #include "src/hierarchies/priors/mnig_prior_model.h" +//! Updater specific for the `UniLinRegLikelihood` used in combination +//! with `MNIGPriorModel`, that is the model +//! y_i | beta, sigsq ~ N(beta^T x_i, sigsq) +//! beta | sigsq ~ N_p(mu0, sigsq * V^{-1}) +//! sigsq ~ InvGamma(a, b) +//! +//! It exploits the conjugacy of the model to sample the full conditional of +//! (beta, sigsq) by calling `MNIGPriorModel::sample` with updated parameters + class MNIGUpdater : public SemiConjugateUpdater { public: diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h index 8a7f52b2d..058a03fe5 100644 --- a/src/hierarchies/updaters/nnig_updater.h +++ b/src/hierarchies/updaters/nnig_updater.h @@ -5,6 +5,15 @@ #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" #include "src/hierarchies/priors/nig_prior_model.h" +//! Updater specific for the `UniNormLikelihood` used in combination +//! with `NIGPriorModel`, that is the model +//! y_i | mu, sigsq ~ N(mu, sigsq) +//! mu | sigsq ~ N(mu0, sigsq / lambda) +//! sigsq ~ InvGamma(a, b) +//! +//! It exploits the conjugacy of the model to sample the full conditional of +//! (mu, sigsq) by calling `NIGPriorModel::sample` with updated parameters + class NNIGUpdater : public SemiConjugateUpdater { public: diff --git a/src/hierarchies/updaters/nnw_updater.h b/src/hierarchies/updaters/nnw_updater.h index bd7977acd..d3c310b1f 100644 --- a/src/hierarchies/updaters/nnw_updater.h +++ b/src/hierarchies/updaters/nnw_updater.h @@ -5,6 +5,15 @@ #include "src/hierarchies/likelihoods/multi_norm_likelihood.h" #include "src/hierarchies/priors/nw_prior_model.h" +//! Updater specific for the `MultiNormLikelihood` used in combination +//! with `NWPriorModel`, that is the model +//! y_i | mu, Sigma ~ Nd(mu, Sigma) +//! mu | Sigma ~ N_d(mu0, sigsq / lambda) +//! Sigma^{-1} ~ Wishart(nu, Psi) +//! +//! It exploits the conjugacy of the model to sample the full conditional of +//! (mu, sigsq) by calling `NWPriorModel::sample` with updated parameters + class NNWUpdater : public SemiConjugateUpdater { public: diff --git a/src/hierarchies/updaters/nnxig_updater.h b/src/hierarchies/updaters/nnxig_updater.h index 52d8f0a45..fee62d5a9 100644 --- a/src/hierarchies/updaters/nnxig_updater.h +++ b/src/hierarchies/updaters/nnxig_updater.h @@ -5,6 +5,15 @@ #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" #include "src/hierarchies/priors/nxig_prior_model.h" +//! Updater specific for the `UniNormLikelihood` used in combination +//! with `NxIGPriorModel`, that is the model +//! y_i | mu, sigsq ~ N(mu, sigsq) +//! mu ~ N(mu0, s0) +//! sigsq ~ InvGamma(a, b) +//! +//! It exploits the semi-conjugacy of the model to sample the full conditional +//! of (mu, sigsq) by calling `NxIGPriorModel::sample` with updated parameters + class NNxIGUpdater : public SemiConjugateUpdater { public: diff --git a/src/hierarchies/updaters/random_walk_updater.h b/src/hierarchies/updaters/random_walk_updater.h index b9559a743..a43d7a6f2 100644 --- a/src/hierarchies/updaters/random_walk_updater.h +++ b/src/hierarchies/updaters/random_walk_updater.h @@ -3,6 +3,17 @@ #include "metropolis_updater.h" +//! Metropolis-Hastings updater using an isotropic proposal function +//! centered in the current value of the parameters (unconstrained). +//! This class requires that the Hierarchy's state implements +//! the `get_unconstrained()`, `set_from_unconstrained()` and +//! `log_det_jac` functions. +//! +//! Given the current value of the unconstrained parameters x, a new +//! value is proposed from +//! x_new ~ N(x_new, step_size * I) +//! and then either accepted (in which case the hierarchy's state is +//! set to x_new) or rejected. class RandomWalkUpdater : public MetropolisUpdater { protected: double step_size; @@ -13,6 +24,12 @@ class RandomWalkUpdater : public MetropolisUpdater { RandomWalkUpdater(double step_size) : step_size(step_size) {} + //! Samples from the proposal distribution + //! @param curr_state the current state (unconstrained parametrization) + //! @param like instance of likelihood + //! @param prior instance of prior + //! @param target_lpdf either double or stan::math::var. Needed for + //! stan's automatic differentiation. It is not used here. template Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, AbstractLikelihood &like, @@ -25,6 +42,13 @@ class RandomWalkUpdater : public MetropolisUpdater { return curr_state + step; } + //! Evaluates the log probability density function of the proposal + //! @param prop_state the proposed state (at which to evaluate the lpdf) + //! @param curr_state the current state (unconstrained parametrization) + //! @param like instance of likelihood + //! @param prior instance of prior + //! @param target_lpdf either double or stan::math::var. Needed for + //! stan's automatic differentiation. It is not used here. template double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, AbstractLikelihood &like, AbstractPriorModel &prior, @@ -36,6 +60,7 @@ class RandomWalkUpdater : public MetropolisUpdater { return out; } + //! Returns a shared_ptr to a new instance of `this` std::shared_ptr clone() const { auto out = std::make_shared( static_cast(*this)); diff --git a/src/hierarchies/updaters/semi_conjugate_updater.h b/src/hierarchies/updaters/semi_conjugate_updater.h index 6517025a0..41e3bc23a 100644 --- a/src/hierarchies/updaters/semi_conjugate_updater.h +++ b/src/hierarchies/updaters/semi_conjugate_updater.h @@ -7,6 +7,25 @@ #include "src/hierarchies/likelihoods/abstract_likelihood.h" #include "src/hierarchies/priors/abstract_prior_model.h" +//! Updater for semi-conjugate hierarchies. +//! +//! We say that a hierarchy is semi-conjugate if the full conditionals +//! of each parameter is in the same parametric family of the prior +//! distribution of that parameter. +//! +//! As a consequence, sampling from the full conditional can be done +//! by calling the `sample` method from the `PriorModel` class, with +//! updater hyperparameters +//! +//! Classes inheriting from this one should only implement the +//! `compute_posterior_hypers(...)` member function +//! +//! This class is templated with respect to +//! @tparam Likelihood: the likelihood of the hierarchy, instance of +//! `AbstractLikelihood` +//! @tparam PriorModel: the prior of the hierarchy, instance of +//! `AbstractPriorModel` + template class SemiConjugateUpdater : public AbstractUpdater { public: @@ -17,6 +36,8 @@ class SemiConjugateUpdater : public AbstractUpdater { void draw(AbstractLikelihood& like, AbstractPriorModel& prior, bool update_params) override; + //! Used by algorithms such as Neal3 and SplitMerge + //! It stores the hyperparameters computed by `compute_posterior_hypers` void save_posterior_hypers(ProtoHypersPtr post_hypers_) override; protected: From eb5372cdd3b9c05635b4f0137afd4d20aed9ab79 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Sat, 7 May 2022 16:11:05 +0200 Subject: [PATCH 272/317] more docs --- src/hierarchies/updaters/target_lpdf_unconstrained.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/hierarchies/updaters/target_lpdf_unconstrained.h b/src/hierarchies/updaters/target_lpdf_unconstrained.h index 9adc8cc02..803900904 100644 --- a/src/hierarchies/updaters/target_lpdf_unconstrained.h +++ b/src/hierarchies/updaters/target_lpdf_unconstrained.h @@ -4,6 +4,10 @@ #include "src/hierarchies/likelihoods/abstract_likelihood.h" #include "src/hierarchies/priors/abstract_prior_model.h" +//! Functor that computes the log-full conditional distribution +//! of a specific hierarchy. +//! Used by metropolis-like updaters especially when the gradient +//! of the target_lpdf if required class target_lpdf_unconstrained { protected: AbstractLikelihood* like; @@ -14,6 +18,9 @@ class target_lpdf_unconstrained { AbstractPriorModel* prior) : like(like), prior(prior) {} + //! Computes the log-full conditional that is simply the + //! sum of `cluster_lpdf_from_unconstrained` in `AbstractLikelihood` + //! and `lpdf_from_unconstrained` in `AbstractPriorModel` template T operator()(const Eigen::Matrix& x) const { return like->cluster_lpdf_from_unconstrained(x) + From 92e9915e191394fee1857b3c78cce20b1c598a5c Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 10 May 2022 17:54:59 +0200 Subject: [PATCH 273/317] comments --- src/hierarchies/updaters/fa_updater.h | 1 - src/hierarchies/updaters/mala_updater.h | 6 +++--- src/hierarchies/updaters/metropolis_updater.h | 14 +++++++------- src/hierarchies/updaters/mnig_updater.h | 10 +++++----- src/hierarchies/updaters/nnw_updater.h | 2 +- src/hierarchies/updaters/random_walk_updater.h | 2 +- 6 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/hierarchies/updaters/fa_updater.h b/src/hierarchies/updaters/fa_updater.h index ec70ec997..b02bdbfca 100644 --- a/src/hierarchies/updaters/fa_updater.h +++ b/src/hierarchies/updaters/fa_updater.h @@ -12,7 +12,6 @@ //! See Bhattacharya, Anirban, and David B. Dunson. //! "Sparse Bayesian infinite factor models." Biometrika (2011): 291-306. //! for further details - class FAUpdater : public AbstractUpdater { public: FAUpdater() = default; diff --git a/src/hierarchies/updaters/mala_updater.h b/src/hierarchies/updaters/mala_updater.h index 57a97246c..8e8d8cdc5 100644 --- a/src/hierarchies/updaters/mala_updater.h +++ b/src/hierarchies/updaters/mala_updater.h @@ -9,7 +9,7 @@ //! //! This class requires that the Hierarchy's state implements //! the `get_unconstrained()`, `set_from_unconstrained()` and -//! `log_det_jac` functions. +//! `log_det_jac()` functions. //! //! Given the current value of the unconstrained parameters x, a new //! value is proposed from @@ -32,7 +32,7 @@ class MalaUpdater : public MetropolisUpdater { //! @param prior instance of prior //! @param target_lpdf either double or stan::math::var. Needed for //! stan's automatic differentiation. It will be - //! filled with the lpdf af the 'curr_state' + //! filled with the lpdf at the 'curr_state' Eigen::VectorXd sample_proposal(Eigen::VectorXd curr_state, AbstractLikelihood &like, AbstractPriorModel &prior, @@ -56,7 +56,7 @@ class MalaUpdater : public MetropolisUpdater { //! @param prior instance of prior //! @param target_lpdf either double or stan::math::var. Needed for //! stan's automatic differentiation. It will be - //! filled with the lpdf af 'curr_state' + //! filled with the lpdf at 'curr_state' double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd curr_state, AbstractLikelihood &like, AbstractPriorModel &prior, target_lpdf_unconstrained &target_lpdf) { diff --git a/src/hierarchies/updaters/metropolis_updater.h b/src/hierarchies/updaters/metropolis_updater.h index ae30aebf8..e4bb93a42 100644 --- a/src/hierarchies/updaters/metropolis_updater.h +++ b/src/hierarchies/updaters/metropolis_updater.h @@ -15,19 +15,19 @@ //! &target_lpdf) //! and //! template -//! double proposal_lpdf(Eigen::VectorXd prop_state, Eigen::VectorXd -//! curr_state, -//! AbstractLikelihood &like, AbstractPriorModel -//! &prior, -//! F &target_lpdf) -//! where the template parameter is neeeded to allow the use of stan's +//! double proposal_lpdf(Eigen::VectorXd prop_state, +//! Eigen::VectorXd curr_state, +//! AbstractLikelihood &like, +//! AbstractPriorModel &prior, +//! F &target_lpdf) +//! where the template parameter is needed to allow the use of stan's //! automatic differentiation if the gradient of the full conditional is //! required. template class MetropolisUpdater : public AbstractUpdater { public: //! Samples from the full conditional distribution using a - //! Metropolis-Hasings step + //! Metropolis-Hastings step void draw(AbstractLikelihood &like, AbstractPriorModel &prior, bool update_params) override { if (update_params) { diff --git a/src/hierarchies/updaters/mnig_updater.h b/src/hierarchies/updaters/mnig_updater.h index 490558d1a..17de64502 100644 --- a/src/hierarchies/updaters/mnig_updater.h +++ b/src/hierarchies/updaters/mnig_updater.h @@ -7,13 +7,13 @@ //! Updater specific for the `UniLinRegLikelihood` used in combination //! with `MNIGPriorModel`, that is the model -//! y_i | beta, sigsq ~ N(beta^T x_i, sigsq) -//! beta | sigsq ~ N_p(mu0, sigsq * V^{-1}) -//! sigsq ~ InvGamma(a, b) +//! y_i | reg_coeffs, var ~ N(reg_coeffs^T x_i, var) +//! reg_coeffs | var ~ N_p(mu0, sigsq * V^{-1}) +//! var ~ InvGamma(a, b) //! //! It exploits the conjugacy of the model to sample the full conditional of -//! (beta, sigsq) by calling `MNIGPriorModel::sample` with updated parameters - +//! (reg_coeffs, var) by calling `MNIGPriorModel::sample` with updated +//! parameters class MNIGUpdater : public SemiConjugateUpdater { public: diff --git a/src/hierarchies/updaters/nnw_updater.h b/src/hierarchies/updaters/nnw_updater.h index d3c310b1f..d7e07c48c 100644 --- a/src/hierarchies/updaters/nnw_updater.h +++ b/src/hierarchies/updaters/nnw_updater.h @@ -12,7 +12,7 @@ //! Sigma^{-1} ~ Wishart(nu, Psi) //! //! It exploits the conjugacy of the model to sample the full conditional of -//! (mu, sigsq) by calling `NWPriorModel::sample` with updated parameters +//! (mu, Sigma) by calling `NWPriorModel::sample` with updated parameters class NNWUpdater : public SemiConjugateUpdater { diff --git a/src/hierarchies/updaters/random_walk_updater.h b/src/hierarchies/updaters/random_walk_updater.h index a43d7a6f2..dae7653ea 100644 --- a/src/hierarchies/updaters/random_walk_updater.h +++ b/src/hierarchies/updaters/random_walk_updater.h @@ -7,7 +7,7 @@ //! centered in the current value of the parameters (unconstrained). //! This class requires that the Hierarchy's state implements //! the `get_unconstrained()`, `set_from_unconstrained()` and -//! `log_det_jac` functions. +//! `log_det_jac()` functions. //! //! Given the current value of the unconstrained parameters x, a new //! value is proposed from From 9111842a3433dcfe542a609032012920dce4cb28 Mon Sep 17 00:00:00 2001 From: AleCarminati Date: Mon, 11 Apr 2022 15:45:30 +0200 Subject: [PATCH 274/317] Corrected segfault in for cycle --- src/algorithms/split_and_merge_algorithm.cc | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/algorithms/split_and_merge_algorithm.cc b/src/algorithms/split_and_merge_algorithm.cc index 2c60a3723..4115723ce 100644 --- a/src/algorithms/split_and_merge_algorithm.cc +++ b/src/algorithms/split_and_merge_algorithm.cc @@ -269,13 +269,10 @@ void SplitAndMergeAlgorithm::proposal_update_allocations( data_to_move_idx = unique_values[label_old_cluster]->get_data_idx(); } - auto curr_it = data_to_move_idx.cbegin(); - auto next_it = curr_it; - next_it++; - auto end_it = data_to_move_idx.cend(); - for (; curr_it != end_it; next_it++, curr_it++) { - const unsigned int curr_idx = *curr_it; - if (next_it == end_it) { + for (auto it = data_to_move_idx.cbegin(); it != data_to_move_idx.cend(); + it++) { + const unsigned int curr_idx = *it; + if (it == (--data_to_move_idx.cend())) { if (split) { unique_values[label_old_cluster]->remove_datum( curr_idx, data.row(curr_idx), update_hierarchy_params()); From c8a447052d1fef29814e289db41315662dcccc04 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 10 May 2022 10:39:01 +0200 Subject: [PATCH 275/317] Fixing doc typos --- src/hierarchies/likelihoods/abstract_likelihood.h | 2 +- src/hierarchies/likelihoods/base_likelihood.h | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hierarchies/likelihoods/abstract_likelihood.h b/src/hierarchies/likelihoods/abstract_likelihood.h index 47f1fcc69..38425829a 100644 --- a/src/hierarchies/likelihoods/abstract_likelihood.h +++ b/src/hierarchies/likelihoods/abstract_likelihood.h @@ -12,7 +12,7 @@ //! Abstract class for a generic likelihood //! //! This class is the basis for a curiously recurring template pattern (CRTP) -//! for `Mixing` objects, and is solely composed of interface functions for +//! for `Likelihood` objects, and is solely composed of interface functions for //! derived classes to use. //! //! A likelihood can evaluate the log probability density faction (lpdf) at a diff --git a/src/hierarchies/likelihoods/base_likelihood.h b/src/hierarchies/likelihoods/base_likelihood.h index bb96f49b0..40ca9dd1d 100644 --- a/src/hierarchies/likelihoods/base_likelihood.h +++ b/src/hierarchies/likelihoods/base_likelihood.h @@ -13,12 +13,12 @@ #include "likelihood_internal.h" #include "src/utils/covariates_getter.h" -//! Base template class of a likelihood object - +//! Base template class of a `Likelihood` object +//! //! This class derives from `AbstractLikelihood` and is templated over //! `Derived` (needed for the curiously recurring template pattern) and //! `State`: an instance of `BaseState` - +//! //! @tparam Derived Name of the implemented derived class //! @tparam State Class name of the container for state values From 14c855cf76b2d63742d24f680164482e126f7e0a Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 10 May 2022 10:39:46 +0200 Subject: [PATCH 276/317] docs for prior models --- src/hierarchies/priors/abstract_prior_model.h | 13 +++++++++++++ src/hierarchies/priors/base_prior_model.h | 14 ++++++++++++++ src/hierarchies/priors/fa_prior_model.h | 8 ++++++++ src/hierarchies/priors/mnig_prior_model.h | 5 +++++ src/hierarchies/priors/nig_prior_model.h | 7 +++++++ src/hierarchies/priors/nw_prior_model.h | 6 ++++++ src/hierarchies/priors/nxig_prior_model.h | 4 ++++ 7 files changed, 57 insertions(+) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index c9ca2f01d..8c240cde6 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -11,6 +11,19 @@ #include "src/hierarchies/likelihoods/states/includes.h" #include "src/utils/rng.h" +//! Abstract class for a generic prior model +//! +//! This class is the basis for a curiously recurring template pattern (CRTP) +//! for `PriorModel` objects, ad it is solely composed of interface functions +//! for derived classes to use. +//! +//! A prior model represents the prior for the parameters in the likelihood. +//! Hence, it can evaluate the log probability density function (lpdf) for a +//! given parameter state. +//! +//! We also store a pointer to the protobuf object that represents the type of +//! prior used fot the parameters in the likelihood. + class AbstractPriorModel { public: // Useful type aliases diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index cda5b6f28..0a3f19e16 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -15,6 +15,20 @@ #include "prior_model_internal.h" #include "src/utils/rng.h" +//! Base template class of a `PriorModel` object +//! +//! This class derives from `AbstractPriorModel` and is templated over +//! `Derived` (needed for the curiously recurring template pattern), `State` +//! (an instance of `Basestate`), `HyperParams` (a struct representing the +//! hyperparameters, see `hyperparams.h`) and `Prior`: a protobuf message +//! representing the type of prior imposed on the hyperparameters. +//! +//! @tparam Derived Name of the implemented derived class +//! @tparam State Class name of the container for state values +//! @tparam HyperParams Class name of the container for hyperparameters +//! @tparam Prior Class name of the protobuf message for the prior on the +//! hyperparameters. + template class BasePriorModel : public AbstractPriorModel { public: diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index dd379a414..959a62a98 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -10,6 +10,14 @@ #include "hyperparams.h" #include "src/utils/rng.h" +//! A prior model for the factor analyzers likelihood, that is +//! mu_j ~ N(mutilde_j, psi) j=1,...,p +//! Lambda ~ DL(alpha) +//! Sigma = diag(sigsq_1,...,sigsq_p) +//! sigsq_j ~ IG(a,b) j=1,...,p +//! Where DL is the Dirichlet-Laplace distribution. See Bhattacharya A., Pati +//! D, Pillai N.S., Dunson D.B. (2015). JASA 110(512), 1479–1490 for details. + class FAPriorModel : public BasePriorModel { diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h index b1a5aff66..9f82f8e60 100644 --- a/src/hierarchies/priors/mnig_prior_model.h +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -10,6 +10,11 @@ #include "hyperparams.h" #include "src/utils/rng.h" +//! A conjugate prior model for the scalar linear regression likelihood, that +//! is +//! reg_coeffs | var ~ N_p(mu, sigsq * Lambda^-1) +//! sigsq ~ IG(a,b) + class MNIGPriorModel : public BasePriorModel { diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index 45b61189e..fe517d350 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -11,6 +11,13 @@ #include "hyperparams.h" #include "src/utils/rng.h" +//! A conjugate prior model for the univariate normal likelihood, that is +//! mu | var ~ N(mu0, var / lambda) +//! var ~ IG(a,b) +//! With several possibilies for hyper-priors on mu and var. We have considered +//! a normal prior for mu0 and a NGG for (mu0, a, b) in addition to fixing +//! prior hyperparameters. + class NIGPriorModel : public BasePriorModel { diff --git a/src/hierarchies/priors/nw_prior_model.h b/src/hierarchies/priors/nw_prior_model.h index e72f0c99a..1f86ff6ec 100644 --- a/src/hierarchies/priors/nw_prior_model.h +++ b/src/hierarchies/priors/nw_prior_model.h @@ -11,6 +11,12 @@ #include "hyperparams.h" #include "src/utils/rng.h" +//! A conjugate prior model for the multivariate normal likelihood, that is +//! mu | Sigma ~ N_p(mu0, Sigma / lambda) +//! Sigma ~ IW(nu0,Psi0) +//! With some options for hyper-priors on mu and Sigma. We have considered a +//! normal prior for mu0 in addition to fixing prior hyperparameters. + class NWPriorModel : public BasePriorModel { diff --git a/src/hierarchies/priors/nxig_prior_model.h b/src/hierarchies/priors/nxig_prior_model.h index 8f5e3f036..a7bcfda65 100644 --- a/src/hierarchies/priors/nxig_prior_model.h +++ b/src/hierarchies/priors/nxig_prior_model.h @@ -11,6 +11,10 @@ #include "hyperparams.h" #include "src/utils/rng.h" +//! A semi-conjugate prior model for the univariate normal likelihood, that is +//! mu ~ N(mu0, var0) +//! var ~ IG(a,b) + class NxIGPriorModel : public BasePriorModel { From 3b952ab937ea964b4bbe06ea8299aebe594b9906 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 10 May 2022 14:11:13 +0200 Subject: [PATCH 277/317] More docs --- src/hierarchies/priors/abstract_prior_model.h | 2 +- src/hierarchies/priors/base_prior_model.h | 2 +- src/hierarchies/priors/mnig_prior_model.h | 3 +-- src/hierarchies/priors/nig_prior_model.h | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/hierarchies/priors/abstract_prior_model.h b/src/hierarchies/priors/abstract_prior_model.h index 8c240cde6..42b7106a5 100644 --- a/src/hierarchies/priors/abstract_prior_model.h +++ b/src/hierarchies/priors/abstract_prior_model.h @@ -22,7 +22,7 @@ //! given parameter state. //! //! We also store a pointer to the protobuf object that represents the type of -//! prior used fot the parameters in the likelihood. +//! prior used for the parameters in the likelihood. class AbstractPriorModel { public: diff --git a/src/hierarchies/priors/base_prior_model.h b/src/hierarchies/priors/base_prior_model.h index 0a3f19e16..366864b4f 100644 --- a/src/hierarchies/priors/base_prior_model.h +++ b/src/hierarchies/priors/base_prior_model.h @@ -27,7 +27,7 @@ //! @tparam State Class name of the container for state values //! @tparam HyperParams Class name of the container for hyperparameters //! @tparam Prior Class name of the protobuf message for the prior on the -//! hyperparameters. +//! hyperparameters. template class BasePriorModel : public AbstractPriorModel { diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h index 9f82f8e60..ae58afc6e 100644 --- a/src/hierarchies/priors/mnig_prior_model.h +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -10,8 +10,7 @@ #include "hyperparams.h" #include "src/utils/rng.h" -//! A conjugate prior model for the scalar linear regression likelihood, that -//! is +//! A conjugate prior model for the scalar linear regression likelihood, i.e. //! reg_coeffs | var ~ N_p(mu, sigsq * Lambda^-1) //! sigsq ~ IG(a,b) diff --git a/src/hierarchies/priors/nig_prior_model.h b/src/hierarchies/priors/nig_prior_model.h index fe517d350..682a24ac2 100644 --- a/src/hierarchies/priors/nig_prior_model.h +++ b/src/hierarchies/priors/nig_prior_model.h @@ -15,8 +15,8 @@ //! mu | var ~ N(mu0, var / lambda) //! var ~ IG(a,b) //! With several possibilies for hyper-priors on mu and var. We have considered -//! a normal prior for mu0 and a NGG for (mu0, a, b) in addition to fixing -//! prior hyperparameters. +//! a normal prior for mu0 and a Normal-Gamma-Gamma for (mu0, a, b) in addition +//! to fixing prior hyperparameters. class NIGPriorModel : public BasePriorModel Date: Wed, 11 May 2022 11:34:25 +0200 Subject: [PATCH 278/317] Notation fix --- src/hierarchies/priors/fa_prior_model.h | 2 +- src/hierarchies/priors/mnig_prior_model.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index 959a62a98..f35ac03fa 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -11,7 +11,7 @@ #include "src/utils/rng.h" //! A prior model for the factor analyzers likelihood, that is -//! mu_j ~ N(mutilde_j, psi) j=1,...,p +//! mu ~ N_p(mutilde, psi*I) //! Lambda ~ DL(alpha) //! Sigma = diag(sigsq_1,...,sigsq_p) //! sigsq_j ~ IG(a,b) j=1,...,p diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h index ae58afc6e..885993db9 100644 --- a/src/hierarchies/priors/mnig_prior_model.h +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -11,8 +11,8 @@ #include "src/utils/rng.h" //! A conjugate prior model for the scalar linear regression likelihood, i.e. -//! reg_coeffs | var ~ N_p(mu, sigsq * Lambda^-1) -//! sigsq ~ IG(a,b) +//! reg_coeffs | var ~ N_p(mu, var * Lambda^-1) +//! var ~ IG(a,b) class MNIGPriorModel : public BasePriorModel Date: Fri, 13 May 2022 10:56:52 +0200 Subject: [PATCH 279/317] Docs for hierarchies --- src/hierarchies/fa_hierarchy.h | 22 ++++++++++++++++ src/hierarchies/lapnig_hierarchy.h | 29 +++++++++++++-------- src/hierarchies/likelihoods/fa_likelihood.h | 4 +-- src/hierarchies/lin_reg_uni_hierarchy.h | 26 +++++++++++++++--- src/hierarchies/nnig_hierarchy.h | 25 +++++++++++++++--- src/hierarchies/nnw_hierarchy.h | 28 ++++++++++++++++++++ src/hierarchies/nnxig_hierarchy.h | 17 ++++++++++++ 7 files changed, 131 insertions(+), 20 deletions(-) diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index 15d278023..ac6c6ad16 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -8,18 +8,37 @@ #include "src/utils/distributions.h" #include "updaters/fa_updater.h" +//! Mixture of Factor Analysers hierarchy for multivariate data. +//! +//! This class represents a hierarchical model where data are distributed +//! according to a multivariate Normal likelihood with a specific factorization +//! of the covariance function (see the `FAHierarchy` class for details). The +//! likelihood parameters have a Dirichlet-Laplace distribution x InverseGamma +//! centering distribution (see the `FAPriorModel` class for details). That is: +//! f(x_i|mu,Sigma,Lambda) = N(mu,Sigma+Lambda*Lambda^T) +//! mu ~ N(mu0,psi*I) +//! Lambda ~ DL(alpha) +//! Sigma = diag(sig1^2,...,sigp^2) +//! sigj^2 ~ IG(a,b) for j=1,...,p +//! where Lambda is the latent score matrix (size p x d with d << p) and +//! DL(alpha) is the Laplace-Dirichlet distribution. +//! See Bhattacharya et al. (2015) for further details. + class FAHierarchy : public BaseHierarchy { public: FAHierarchy() = default; ~FAHierarchy() = default; + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::FA; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Initialize likelihood dimension to prior one like->set_dim(prior->get_dim()); @@ -37,6 +56,9 @@ class FAHierarchy } }; +//! Empirical-Bayes hyperparameters initialization for the FA HIerarchy. +//! Sets the hyperparameters in `hier` starting from the data on which the user +//! wants to fit the model. inline void set_fa_hyperparams_from_data(FAHierarchy* hier) { auto dataset_ptr = std::static_pointer_cast(hier->get_likelihood()) diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/lapnig_hierarchy.h index a7f049115..000b60bd5 100644 --- a/src/hierarchies/lapnig_hierarchy.h +++ b/src/hierarchies/lapnig_hierarchy.h @@ -1,22 +1,26 @@ #ifndef BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ -// #include - -// #include -// #include -// #include - -// #include "algorithm_state.pb.h" -// #include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -// #include "hierarchy_prior.pb.h" - #include "base_hierarchy.h" +#include "hierarchy_id.pb.h" #include "likelihoods/laplace_likelihood.h" #include "priors/nxig_prior_model.h" #include "updaters/mala_updater.h" +//! Laplace Normal-InverseGamma hierarchy for univariate data. + +//! This class represents a hierarchical model where data are distributed +//! according to a laplace likelihood (see the `LaplaceLikelihood` class for +//! deatils).The likelihood parameters have a Normal x InverseGamma centering +//! distribution (see the `NxIGPriorModel` class for details). That is: +//! f(x_i|mu,lambda) = Laplace(mu,sqrt(var/2)) +//! mu ~ N(mu0,sig0^2) +//! var ~ IG(alpha0,beta0) +//! The state is composed of mean and variance (thus the scale for the Laplace +//! distribution is sqrt(var / 2)). The state hyperparameters are (mu_0, +//! sig0^2, alpha0, beta0), all scalar values. Note that this hierarchy is NOT +//! conjugate, thus the marginal distribution is not available in closed form. + class LapNIGHierarchy : public BaseHierarchy { @@ -24,12 +28,15 @@ class LapNIGHierarchy LapNIGHierarchy() = default; ~LapNIGHierarchy() = default; + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::LapNIG; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Get hypers auto hypers = prior->get_hypers(); diff --git a/src/hierarchies/likelihoods/fa_likelihood.h b/src/hierarchies/likelihoods/fa_likelihood.h index 0c5e4938d..8a40bb6f1 100644 --- a/src/hierarchies/likelihoods/fa_likelihood.h +++ b/src/hierarchies/likelihoods/fa_likelihood.h @@ -13,8 +13,8 @@ #include "states/includes.h" //! A gaussian factor analytic likelihood, that is -//! Y ~ N_p(mu, Lambda * Lambda^T + Psi) -//! Where Lambda is a `p x d` matrix, usually d << p and `Psi` is a diagonal +//! Y ~ N_p(mu, Sigma + Lambda * Lambda^T) +//! Where Lambda is a `p x d` matrix, usually d << p and `Sigma` is a diagonal //! matrix. //! //! Parameters are stored in a `State::FA` state. diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index ae449a103..a55a51251 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -7,21 +7,37 @@ #include "priors/mnig_prior_model.h" #include "updaters/mnig_updater.h" +//! Linear regression hierarchy for univariate data. +//! +//! This class implements a dependent hierarchy which represents the classical +//! univariate Bayesian linear regression model, i.e.: +//! y_i | \beta, x_i, \sigma^2 \sim N(\beta^T x_i, sigma^2) +//! \beta | \sigma^2 \sim N(\mu, sigma^2 Lambda^{-1}) +//! \sigma^2 \sim InvGamma(a, b) +//! +//! The state consists of the `regression_coeffs` \beta, and the `var` sigma^2. +//! Lambda is called the variance-scaling factor. Note that this hierarchy is +//! conjugate, thus the marginal distribution is available in closed form. For +//! more information, please refer to parent classes: `BaseHierarchy`, +//! `UniLinRegLikelihood` for deatails on the likelihood model, and +//! `MNIGPriorModel` for details on the prior model. + class LinRegUniHierarchy : public BaseHierarchy { public: + LinRegUniHierarchy() = default; ~LinRegUniHierarchy() = default; - using BaseHierarchy::BaseHierarchy; - + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::LinRegUni; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Initialize likelihood dimension to prior one like->set_dim(prior->get_dim()); @@ -34,6 +50,10 @@ class LinRegUniHierarchy like->set_state(state); }; + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const override { auto params = hier_params->lin_reg_uni_state(); diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index b70e0eeac..36d14c94e 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -7,20 +7,33 @@ #include "priors/nig_prior_model.h" #include "updaters/nnig_updater.h" +//! Conjugate Normal Normal-InverseGamma hierarchy for univariate data. +//! +//! This class represents a hierarchical model where data are distributed +//! according to a Normal likelihood (see the `UniNormLikelihood` class for +//! details). The likelihood parameters have a Normal-InverseGamma centering +//! distribution (see the `NIGPriorModel` class for details). That is: +//! f(x_i|mu,sig^2) = N(mu,sig^2) +//! (mu,sig^2) ~ N-IG(mu0, lambda0, alpha0, beta0) +//! The state is composed of mean and variance. The state hyperparameters are +//! (mu_0, lambda0, alpha0, beta0), all scalar values. Note that this hierarchy +//! is conjugate, thus the marginal distribution is available in closed form. + class NNIGHierarchy : public BaseHierarchy { public: + NNIGHierarchy() = default; ~NNIGHierarchy() = default; - using BaseHierarchy::BaseHierarchy; - + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNIG; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Get hypers auto hypers = prior->get_hypers(); @@ -31,6 +44,10 @@ class NNIGHierarchy like->set_state(state); }; + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { auto params = hier_params->nnig_state(); @@ -41,4 +58,4 @@ class NNIGHierarchy } }; -#endif +#endif // BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index 35e5920af..ffe13b645 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -8,18 +8,36 @@ #include "src/utils/distributions.h" #include "updaters/nnw_updater.h" +//! Normal Normal-Wishart hierarchy for multivariate data. + +//! This class represents a hierarchy, i.e. a cluster, whose multivariate data +//! are distributed according to a multivariate normal likelihood (see the +//! `MultiNormLikelihood` for details). The likelihood parameters have a +//! Normal-Wishart centering distribution (see the `NWPriorModel` class for +//! details). That is: f(x_i|mu,tau) = N(mu,tau^{-1}) +//! (mu,tau) ~ NW(mu0, lambda0, tau0, nu0) +//! The state is composed of mean and precision matrix. The Cholesky factor and +//! log-determinant of the latter are also included in the container for +//! efficiency reasons. The state's hyperparameters are (mu0, lambda0, tau0, +//! nu0), which are respectively vector, scalar, matrix, and scalar. Note that +//! this hierarchy is conjugate, thus the marginal distribution is available in +//! closed form. + class NNWHierarchy : public BaseHierarchy { public: NNWHierarchy() = default; ~NNWHierarchy() = default; + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNW; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Initialize likelihood dimension to prior one like->set_dim(prior->get_dim()); @@ -34,6 +52,10 @@ class NNWHierarchy like->set_state(state); }; + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { HyperParams pred_params = get_predictive_t_parameters(hier_params); @@ -44,6 +66,12 @@ class NNWHierarchy logdet); } + //! Helper function that computes the predictive parameters for the + //! multivariate t distribution from the current hyperparameter values. It is + //! used to efficiently compute the log-marginal distribution of data. + //! @param hier_params Container of (prior or posterior) hyperparameter + //! values + //! @return A `HyperParam` object with the predictive parameters HyperParams get_predictive_t_parameters(ProtoHypersPtr hier_params) const { auto params = hier_params->nnw_state(); // Compute dof and scale of marginal distribution diff --git a/src/hierarchies/nnxig_hierarchy.h b/src/hierarchies/nnxig_hierarchy.h index a0de895fe..1901083c6 100644 --- a/src/hierarchies/nnxig_hierarchy.h +++ b/src/hierarchies/nnxig_hierarchy.h @@ -7,18 +7,35 @@ #include "priors/nxig_prior_model.h" #include "updaters/nnxig_updater.h" +//! Semi-conjugate Normal Normal x InverseGamma hierarchy for univariate data. +//! +//! This class represents a hierarchical model where data are distributed +//! according to a Normal likelihood (see the `UniNormLikelihood` class for +//! details). The likelihood parameters have a Normal x InverseGamma centering +//! distribution (see the `NxIGPriorModel` class for details). That is: +//! f(x_i|mu,sig^2) = N(mu,sig^2) +//! mu ~ N(mu0, sig0^2) +//! sig^2 ~ IG(alpha0, beta0) +//! The state is composed of mean and variance. The state hyperparameters are +//! (mu_0, sig0^2, alpha0, beta0), all scalar values. Note that this hierarchy +//! is NOT conjugate, meaning that the marginal distribution is not available +//! in closed form. + class NNxIGHierarchy : public BaseHierarchy { public: NNxIGHierarchy() = default; ~NNxIGHierarchy() = default; + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNxIG; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Get hypers auto hypers = prior->get_hypers(); From bff54a1e47793c996766d092bfb872ac4c7660e7 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 10:25:46 +0200 Subject: [PATCH 280/317] Improve docs --- src/hierarchies/fa_hierarchy.h | 2 +- src/hierarchies/lin_reg_uni_hierarchy.h | 41 +++++++++++++++---------- src/hierarchies/nnw_hierarchy.h | 9 +++--- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index ac6c6ad16..f98dbc7c5 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -12,7 +12,7 @@ //! //! This class represents a hierarchical model where data are distributed //! according to a multivariate Normal likelihood with a specific factorization -//! of the covariance function (see the `FAHierarchy` class for details). The +//! of the covariance matrix (see the `FAHierarchy` class for details). The //! likelihood parameters have a Dirichlet-Laplace distribution x InverseGamma //! centering distribution (see the `FAPriorModel` class for details). That is: //! f(x_i|mu,Sigma,Lambda) = N(mu,Sigma+Lambda*Lambda^T) diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index a55a51251..50d4239c2 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -7,20 +7,25 @@ #include "priors/mnig_prior_model.h" #include "updaters/mnig_updater.h" -//! Linear regression hierarchy for univariate data. -//! -//! This class implements a dependent hierarchy which represents the classical -//! univariate Bayesian linear regression model, i.e.: -//! y_i | \beta, x_i, \sigma^2 \sim N(\beta^T x_i, sigma^2) -//! \beta | \sigma^2 \sim N(\mu, sigma^2 Lambda^{-1}) -//! \sigma^2 \sim InvGamma(a, b) -//! -//! The state consists of the `regression_coeffs` \beta, and the `var` sigma^2. -//! Lambda is called the variance-scaling factor. Note that this hierarchy is -//! conjugate, thus the marginal distribution is available in closed form. For -//! more information, please refer to parent classes: `BaseHierarchy`, -//! `UniLinRegLikelihood` for deatails on the likelihood model, and -//! `MNIGPriorModel` for details on the prior model. +/** + * Linear regression hierarchy for univariate data. + * + * This class implements a dependent hierarchy which represents the classical + * univariate Bayesian linear regression model, i.e.: + * + * \f[ + * y_i \mid \beta, x_i, \sigma^2 &\sim N(\beta^T x_i, \sigma^2) \\ + * \beta \mid \sigma^2 &\sim N(\mu, \sigma^2 \Lambda^{-1}) \\ + * \sigma^2 &\sim InvGamma(a, b) + * \f] + * + * The state consists of the `regression_coeffs` \f$ \beta \f$, and the `var` + * \f$ \sigma^2 \f$. Lambda is called the variance-scaling factor. Note that + * this hierarchy is conjugate, thus the marginal distribution is available in + * closed form. For more information, please refer to the parent class + * `BaseHierarchy`, to the class `UniLinRegLikelihood` for details on the + * likelihood model and to `MNIGPriorModel` for details on the prior model. + */ class LinRegUniHierarchy : public BaseHierarchylin_reg_uni_state(); diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index ffe13b645..cb0744531 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -10,7 +10,7 @@ //! Normal Normal-Wishart hierarchy for multivariate data. -//! This class represents a hierarchy, i.e. a cluster, whose multivariate data +//! This class represents a hierarchy whose multivariate data //! are distributed according to a multivariate normal likelihood (see the //! `MultiNormLikelihood` for details). The likelihood parameters have a //! Normal-Wishart centering distribution (see the `NWPriorModel` class for @@ -53,9 +53,10 @@ class NNWHierarchy }; //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf + //! @param hier_params Container of (prior or posterior) hyperparameter + //! values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { HyperParams pred_params = get_predictive_t_parameters(hier_params); From 2b4247e4126c2672984eef8d7573f84e05beb98d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:31:56 +0200 Subject: [PATCH 281/317] Fix doxygen warnings --- src/algorithms/marginal_algorithm.h | 9 ++++++--- src/mixings/dirichlet_mixing.h | 1 + src/mixings/mixture_finite_mixing.h | 1 + src/mixings/pityor_mixing.h | 1 + src/runtime/factory.h | 2 +- 5 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/algorithms/marginal_algorithm.h b/src/algorithms/marginal_algorithm.h index 9b1aac89d..34af2167a 100644 --- a/src/algorithms/marginal_algorithm.h +++ b/src/algorithms/marginal_algorithm.h @@ -45,9 +45,12 @@ class MarginalAlgorithm : public BaseAlgorithm { protected: //! Computes marginal contribution of the given cluster to the lpdf estimate - //! @param hier Pointer to the `Hierarchy` object representing the cluster - //! @param grid Grid of row points on which the density is to be evaluated - //! @return The marginal component of the estimate + //! @param hier Pointer to the `Hierarchy` object representing the + //! cluster + //! @param grid Grid of row points on which the density is to be + //! evaluated + //! @param covariate (Optional) covariate vectors associated to data + //! @return The marginal component of the estimate virtual Eigen::VectorXd lpdf_marginal_component( const std::shared_ptr hier, const Eigen::MatrixXd &grid, diff --git a/src/mixings/dirichlet_mixing.h b/src/mixings/dirichlet_mixing.h index 27e9416c4..bc31c5cd2 100644 --- a/src/mixings/dirichlet_mixing.h +++ b/src/mixings/dirichlet_mixing.h @@ -60,6 +60,7 @@ class DirichletMixing protected: //! Returns probability mass for an old cluster (for marginal mixings only) //! @param n Total dataset size + //! @param n_clust Number of clusters //! @param log Whether to return logarithm-scale values or not //! @param propto Whether to include normalizing constants or not //! @param hier `Hierarchy` object representing the cluster diff --git a/src/mixings/mixture_finite_mixing.h b/src/mixings/mixture_finite_mixing.h index 3b85a4d10..c1adc3dfa 100644 --- a/src/mixings/mixture_finite_mixing.h +++ b/src/mixings/mixture_finite_mixing.h @@ -70,6 +70,7 @@ class MixtureFiniteMixing protected: //! Returns probability mass for an old cluster (for marginal mixings only) //! @param n Total dataset size + //! @param n_clust Number of clusters //! @param log Whether to return logarithm-scale values or not //! @param propto Whether to include normalizing constants or not //! @param hier `Hierarchy` object representing the cluster diff --git a/src/mixings/pityor_mixing.h b/src/mixings/pityor_mixing.h index 67d9218ac..5ebec6958 100644 --- a/src/mixings/pityor_mixing.h +++ b/src/mixings/pityor_mixing.h @@ -62,6 +62,7 @@ class PitYorMixing protected: //! Returns probability mass for an old cluster (for marginal mixings only) //! @param n Total dataset size + //! @param n_clust Number of clusters //! @param log Whether to return logarithm-scale values or not //! @param propto Whether to include normalizing constants or not //! @param hier `Hierarchy` object representing the cluster diff --git a/src/runtime/factory.h b/src/runtime/factory.h index f496bd42f..8a4bbba31 100644 --- a/src/runtime/factory.h +++ b/src/runtime/factory.h @@ -77,7 +77,7 @@ class Factory { //! Adds a builder function to the storage //! @param id Identifier to associate the builder with - //! @param bulider Builder function for a specific object type + //! @param builder Builder function for a specific object type void add_builder(const Identifier &id, const Builder &builder) { storage.insert(std::make_pair(id, builder)); } From ad97f1b39947c70f78f17afe6fb212e399b0943d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:32:55 +0200 Subject: [PATCH 282/317] Fix utils.rst bad rendering --- src/utils/cluster_utils.h | 7 ++- src/utils/distributions.h | 124 ++++++++++++++++++++------------------ src/utils/eigen_utils.h | 7 ++- src/utils/io_utils.h | 5 +- src/utils/proto_utils.h | 11 ++-- src/utils/rng.h | 3 +- 6 files changed, 84 insertions(+), 73 deletions(-) diff --git a/src/utils/cluster_utils.h b/src/utils/cluster_utils.h index 0fe166399..1930bab16 100644 --- a/src/utils/cluster_utils.h +++ b/src/utils/cluster_utils.h @@ -3,15 +3,18 @@ #include -//! This file includes some utilities for cluster estimation. These functions -//! only use Eigen ojects. +//! \file cluster_utils.h +//! The `cluster_utils.h` file includes some utilities for cluster estimation. +//! These functions only use Eigen objects. namespace bayesmix { + //! Computes the posterior similarity matrix the data Eigen::MatrixXd posterior_similarity(const Eigen::MatrixXd &alloc_chain); //! Estimates the clustering structure of the data via LS minimization Eigen::VectorXi cluster_estimate(const Eigen::MatrixXi &alloc_chain); + } // namespace bayesmix #endif // BAYESMIX_UTILS_CLUSTER_UTILS_H_ diff --git a/src/utils/distributions.h b/src/utils/distributions.h index 7fcd2bbe0..d6b7a0635 100644 --- a/src/utils/distributions.h +++ b/src/utils/distributions.h @@ -7,14 +7,15 @@ #include "algorithm_state.pb.h" -//! This file includes several useful functions related to probability -//! distributions, including categorical variables, popular multivariate -//! distributions, and distribution distances. Some of these functions make use -//! of OpenMP parallelism to achieve better efficiency. +//! @file distributions.h +//! The `distributions.h` file includes several useful functions related to +//! probability distributions, including categorical variables, popular +//! multivariate distributions, and distribution distances. Some of these +//! functions make use of OpenMP parallelism to achieve better efficiency. namespace bayesmix { -/* +/** * Returns a pseudorandom categorical random variable on the set * {start, ..., start + k} where k is the size of the given probability vector * @@ -26,64 +27,66 @@ namespace bayesmix { int categorical_rng(const Eigen::VectorXd &probas, std::mt19937_64 &rng, const int start = 0); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution parametrized by mean and precision matrix on a single point * - * @param datum Point in which to evaluate the the lpdf - * @param mean The mean of the Gaussian distribution - * @prec_chol The (lower) Cholesky factor of the precision matrix - * @prec_logdet The logarithm of the determinant of the precision matrix - * @return The evaluation of the lpdf + * @param datum Point in which to evaluate the the lpdf + * @param mean The mean of the Gaussian distribution + * @param prec_chol The (lower) Cholesky factor of the precision matrix + * @param prec_logdet The logarithm of the determinant of the precision + * matrix + * @return The evaluation of the lpdf */ double multi_normal_prec_lpdf(const Eigen::VectorXd &datum, const Eigen::VectorXd &mean, const Eigen::MatrixXd &prec_chol, const double prec_logdet); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution parametrized by mean and precision matrix on multiple points * - * @param data Grid of points (by row) on which to evaluate the lpdf - * @param mean The mean of the Gaussian distribution - * @prec_chol The (lower) Cholesky factor of the precision matrix - * @prec_logdet The logarithm of the determinant of the precision matrix - * @return The evaluation of the lpdf + * @param data Grid of points (by row) on which to evaluate the lpdf + * @param mean The mean of the Gaussian distribution + * @param prec_chol The (lower) Cholesky factor of the precision matrix + * @param prec_logdet The logarithm of the determinant of the precision + * matrix + * @return The evaluation of the lpdf */ Eigen::VectorXd multi_normal_prec_lpdf_grid(const Eigen::MatrixXd &data, const Eigen::VectorXd &mean, const Eigen::MatrixXd &prec_chol, const double prec_logdet); -/* +/** * Returns a pseudorandom multivariate normal random variable with diagonal * covariance matrix * - * @param mean The mean of the Gaussian r.v. - * @param cov_diag The diagonal covariance matrix - * @rng Random number generator - * @return multivariate normal r.v. + * @param mean The mean of the Gaussian r.v. + * @param cov_diag The diagonal covariance matrix + * @param rng Random number generator + * @return Multivariate normal r.v. */ Eigen::VectorXd multi_normal_diag_rng( const Eigen::VectorXd &mean, const Eigen::DiagonalMatrix &cov_diag, std::mt19937_64 &rng); -/* +/** * Returns a pseudorandom multivariate normal random variable parametrized * through mean and Cholesky decomposition of precision matrix * - * @param mean The mean of the Gaussian r.v. - * @prec_chol The (lower) Cholesky factor of the precision matrix - * @param rng Random number generator - * @return multivariate normal r.v. + * @param mean The mean of the Gaussian r.v. + * @param prec_chol The (lower) Cholesky factor of the precision matrix + * @param rng Random number generator + * @return Multivariate normal r.v. */ Eigen::VectorXd multi_normal_prec_chol_rng( const Eigen::VectorXd &mean, const Eigen::LLT &prec_chol, std::mt19937_64 &rng); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution with the following covariance structure: * Sigma + Lambda * Lambda^T @@ -91,7 +94,6 @@ Eigen::VectorXd multi_normal_prec_chol_rng( * y^T*(Sigma + Lambda * Lambda^T)^{-1}*y = y^T*Sigma^{-1}*y - * ||wood_factor*y||^2 * - * * @param datum Point on which to evaluate the lpdf * @param mean The mean of the Gaussian distribution * @param sigma_diag_inverse The inverse of the diagonal of Sigma matrix @@ -106,7 +108,7 @@ double multi_normal_lpdf_woodbury_chol( const Eigen::DiagonalMatrix &sigma_diag_inverse, const Eigen::MatrixXd &wood_factor, const double &cov_logdet); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution with the following covariance structure: * Sigma + Lambda * Lambda^T @@ -116,41 +118,42 @@ double multi_normal_lpdf_woodbury_chol( * computation from being O(p^3) to being O(d^3 p) which gives a substantial * speedup when p >> d * - * @param datum Point on which to evaluate the lpdf - * @param mean The mean of the Gaussian distribution + * @param datum Point on which to evaluate the lpdf + * @param mean The mean of the Gaussian distribution * @param sigma_diag The diagonal of Sigma matrix * @param lambda Rectangular matrix in Woodbury Identity - * @return The evaluation of the lpdf + * @return The evaluation of the lpdf */ double multi_normal_lpdf_woodbury(const Eigen::VectorXd &datum, const Eigen::VectorXd &mean, const Eigen::VectorXd &sigma_diag, const Eigen::MatrixXd &lambda); -/* - * Returns the log-determinant of the matrix Lambda Lambda^T + Sigma +/** + * Returns the log-determinant of the matrix \f$ \Lambda\Lambda^T + \Sigma \f$ * and the 'wood_factor', i.e. * L^{-1} * Lambda^T * Sigma^{-1}, * where L is the (lower) Cholesky factor of * I + Lambda^T * Sigma^{-1} * Lambda * - * @param sigma_dag_inverse The inverse of the diagonal matrix Sigma + * @param sigma_diag_inverse The inverse of the diagonal matrix Sigma * @param lambda The matrix Lambda */ std::pair compute_wood_chol_and_logdet( const Eigen::DiagonalMatrix &sigma_diag_inverse, const Eigen::MatrixXd &lambda); -/* +/** * Evaluates the log probability density function of a multivariate Student's t * distribution on a single point * - * @param datum Point in which to evaluate the the lpdf - * @param df The degrees of freedom of the Student's t distribution - * @param mean The mean of the Student's t distribution - * @invscale_chol The (lower) Cholesky factor of the inverse scale matrix - * @prec_logdet The logarithm of the determinant of the inverse scale matrix - * @return The evaluation of the lpdf + * @param datum Point in which to evaluate the the lpdf + * @param df The degrees of freedom of the Student's t distribution + * @param mean The mean of the Student's t distribution + * @param invscale_chol The (lower) Cholesky factor of the inverse scale matrix + * @param scale_logdet The logarithm of the determinant of the inverse scale + * matrix + * @return The evaluation of the lpdf */ double multi_student_t_invscale_lpdf(const Eigen::VectorXd &datum, const double df, @@ -158,28 +161,29 @@ double multi_student_t_invscale_lpdf(const Eigen::VectorXd &datum, const Eigen::MatrixXd &invscale_chol, const double scale_logdet); -/* +/** * Evaluates the log probability density function of a multivariate Student's t * distribution on multiple points * - * @param data Grid of points (by row) on which to evaluate the lpdf - * @param df The degrees of freedom of the Student's t distribution - * @param mean The mean of the Student's t distribution - * @invscale_chol The (lower) Cholesky factor of the inverse scale matrix - * @prec_logdet The logarithm of the determinant of the inverse scale matrix - * @return The evaluation of the lpdf + * @param data Grid of points (by row) on which to evaluate the lpdf + * @param df The degrees of freedom of the Student's t distribution + * @param mean The mean of the Student's t distribution + * @param invscale_chol The (lower) Cholesky factor of the inverse scale matrix + * @param scale_logdet The logarithm of the determinant of the inverse scale + * matrix + * @return The evaluation of the lpdf */ Eigen::VectorXd multi_student_t_invscale_lpdf_grid( const Eigen::MatrixXd &data, const double df, const Eigen::VectorXd &mean, const Eigen::MatrixXd &invscale_chol, const double scale_logdet); -/* +/** * Computes the L^2 distance between the univariate mixture of Gaussian - * densities p1(x) = \sum_{h=1}^m1 w1[h] N(x | mean1[h], var1[h]) and - * p2(x) = \sum_{h=1}^m2 w2[h] N(x | mean2[h], var2[h]) + * densities p1(x) = sum_{h=1}^m1 w1[h] N(x | mean1[h], var1[h]) and + * p2(x) = sum_{h=1}^m2 w2[h] N(x | mean2[h], var2[h]) * * The L^2 distance amounts to - * d(p, q) = (\int (p(x) - q(x)^2 dx))^{1/2} + * d(p, q) = (int (p(x) - q(x)^2 dx))^{1/2} */ double gaussian_mixture_dist(const Eigen::VectorXd &means1, const Eigen::VectorXd &vars1, @@ -188,13 +192,13 @@ double gaussian_mixture_dist(const Eigen::VectorXd &means1, const Eigen::VectorXd &vars2, const Eigen::VectorXd &weights2); -/* +/** * Computes the L^2 distance between the multivariate mixture of Gaussian - * densities p1(x) = \sum_{h=1}^m1 w1[h] N(x | mean1[h], Prec[1]^{-1}) and - * p2(x) = \sum_{h=1}^m2 w2[h] N(x | mean2[h], Prec2[h]^{-1}) + * densities p1(x) = sum_{h=1}^m1 w1[h] N(x | mean1[h], Prec[1]^{-1}) and + * p2(x) = sum_{h=1}^m2 w2[h] N(x | mean2[h], Prec2[h]^{-1}) * * The L^2 distance amounts to - * d(p, q) = (\int (p(x) - q(x)^2 dx))^{1/2} + * d(p, q) = (int (p(x) - q(x)^2 dx))^{1/2} */ double gaussian_mixture_dist(const std::vector &means1, const std::vector &precs1, @@ -203,11 +207,11 @@ double gaussian_mixture_dist(const std::vector &means1, const std::vector &precs2, const Eigen::VectorXd &weights2); -/* +/** * Computes the L^2 distance between the mixture of Gaussian * densities p(x) and q(x). These could be either univariate or multivariate. * The L2 distance amounts to - * d(p, q) = (\int (p(x) - q(x)^2 dx))^{1/2} + * d(p, q) = (int (p(x) - q(x)^2 dx))^{1/2} * * @param clus1, clus2 Cluster-specific parameters of the mix. densities * @param weights1, weights2 Weigths of the mixture densities diff --git a/src/utils/eigen_utils.h b/src/utils/eigen_utils.h index 463cb3979..2307b7f6f 100644 --- a/src/utils/eigen_utils.h +++ b/src/utils/eigen_utils.h @@ -4,9 +4,10 @@ #include #include -//! This file implements a few methods to manipulate groups of matrices, mainly -//! by joining different objects, as well as additional utilities for SPD -//! checking and grid creation. +//! @file eigen_utils.h +//! The `eigen_utils.h` file implements a few methods to manipulate groups of +//! matrices, mainly by joining different objects, as well as additional +//! utilities for SPD checking and grid creation. namespace bayesmix { //! Concatenates a vector of Eigen matrices along the rows diff --git a/src/utils/io_utils.h b/src/utils/io_utils.h index 89b830e34..b9c4231a6 100644 --- a/src/utils/io_utils.h +++ b/src/utils/io_utils.h @@ -3,8 +3,9 @@ #include -//! This file implements basic input-output utilities for Eigen matrices from -//! and to text files. +//! @file io_utils.h +//! The `io_utils.h` file implements basic input-output utilities for Eigen +//! matrices from and to text files. namespace bayesmix { //! Checks whether the given file is available for writing diff --git a/src/utils/proto_utils.h b/src/utils/proto_utils.h index cd5466d0a..cb8c3333d 100644 --- a/src/utils/proto_utils.h +++ b/src/utils/proto_utils.h @@ -5,11 +5,12 @@ #include "matrix.pb.h" -//! This file implements a few useful functions to manipulate Protobuf objects. -//! For instance, this library implements its own version of vectors and -//! matrices, and the functions implemented here convert from these types to -//! the Eigen ones and viceversa. One can also read a Protobuf from a text -//! file. This is mostly useful for algorithm configuration files. +//! @file proto_utils.h +//! The `proto_utils.h` file implements a few useful functions to manipulate +//! Protobuf objects. For instance, this library implements its own version of +//! vectors and matrices, and the functions implemented here convert from these +//! types to the Eigen ones and viceversa. One can also read a Protobuf from a +//! text file. This is mostly useful for algorithm configuration files. namespace bayesmix { diff --git a/src/utils/rng.h b/src/utils/rng.h index 8fc6e6264..fc71a2f69 100644 --- a/src/utils/rng.h +++ b/src/utils/rng.h @@ -3,7 +3,8 @@ #include -//! Simple Random Number Generation class wrapper. +//! @file rng.h +//! The `rng.h` file defines a simple Random Number Generation class wrapper. //! This class wraps the C++ standard RNG object and allows the use of any RNG //! seed. It is implemented as a singleton, so that every object used in the //! library has access to the same exact RNG engine. From cadcd600bd04e4ab1a9e402a5750eaf08d19da24 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:33:59 +0200 Subject: [PATCH 283/317] Uniform docs for likelihoods (+ latex in doxygen) --- src/hierarchies/likelihoods/fa_likelihood.h | 25 +++++++++------- .../likelihoods/laplace_likelihood.h | 26 +++++++++------- .../likelihoods/multi_norm_likelihood.h | 20 ++++++++----- .../likelihoods/uni_lin_reg_likelihood.h | 30 ++++++++++--------- .../likelihoods/uni_norm_likelihood.h | 20 ++++++++----- 5 files changed, 70 insertions(+), 51 deletions(-) diff --git a/src/hierarchies/likelihoods/fa_likelihood.h b/src/hierarchies/likelihoods/fa_likelihood.h index 8a40bb6f1..3e2e08e40 100644 --- a/src/hierarchies/likelihoods/fa_likelihood.h +++ b/src/hierarchies/likelihoods/fa_likelihood.h @@ -12,16 +12,21 @@ #include "base_likelihood.h" #include "states/includes.h" -//! A gaussian factor analytic likelihood, that is -//! Y ~ N_p(mu, Sigma + Lambda * Lambda^T) -//! Where Lambda is a `p x d` matrix, usually d << p and `Sigma` is a diagonal -//! matrix. -//! -//! Parameters are stored in a `State::FA` state. -//! We store as summary statistics the sum of the y_i's, but it is -//! not sufficient for all the updates involved. Therefore, all the -//! observations allocated to a cluster are processed when computing the -//! cluster lpdf. +/** + * A gaussian factor analytic likelihood, using the `State::FA` state. + * Represents the model: + * + * \f[ + * \bm{y}_1,\dots,\bm{y}_k \stackrel{\small\mathrm{iid}}{\sim} N_p(\bm{\mu}, + * \Sigma + \Lambda\Lambda^T), \f] + * + * where Lambda is a \f$ p \times d \f$ matrix, usually \f$ d << p \f$ and \f$ + * \Sigma \f$ is a diagonal matrix. Parameters are stored in a `State::FA` + * state. We store as summary statistics the sum of the \f$ \bm{y}_i \f$'s, but + * it is not sufficient for all the updates involved. Therefore, all the + * observations allocated to a cluster are processed when computing the + * cluster lpdf. + */ class FALikelihood : public BaseLikelihood { public: diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h index 6b9196b8a..9d4c25128 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.h +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -11,17 +11,21 @@ #include "base_likelihood.h" #include "states/includes.h" -//! A univariate laplace likelihood -//! -//! Represents the model: -//! y_i ~ Laplace(mu, var) -//! where mu is the mean and center of the distribution -//! and var is the variance. The scale is then sqrt(var / 2) -//! These parameters are stored in a `State::UniLS` state -//! -//! Since the Laplace likelihood does not have sufficient statistics -//! other than the whole sample, the `update_sum_stats` method -//! does nothing. +/** + * A univariate Laplace likelihood, using the `State::UniLS` state. Represents + * the model: + * + * \f[ + * y_1,\dots,y_k \mid \mu, \sigma^2 \stackrel{\small\mathrm{iid}}{\sim} + * Laplace(\mu,\sigma^2), \f] + * + * where \f$ \mu \f$ is the mean and center of the distribution + * and \f$ \sigma^2 \f$ is the variance. The scale parameter \f$ \lambda \f$ is + * then \f$ \sqrt{\sigma^2/2} \f$. These parameters are stored in a + * `State::UniLS` state. Since the Laplace likelihood does not have sufficient + * statistics other than the whole sample, the `update_sum_stats()` method does + * nothing. + */ class LaplaceLikelihood : public BaseLikelihood { diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index 13f338faa..6e249a75e 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -12,14 +12,18 @@ #include "base_likelihood.h" #include "states/includes.h" -//! A multivariate normal likelihood -//! -//! Represents the model: -//! y_1, ..., y_m ~ N(mu, Cov) -//! where (mu, Cov) are stored in a `State::MultiLS` state -//! -//! The sufficient statistics stored are the sum of the y_i's -//! and the sum of y_i^T y_i. +/** + * A multivariate normal likelihood, using the `State::MultiLS` state. + * Represents the model: + * + * \f[ + * \bm{y}_1,\dots, \bm{y}_k \stackrel{\small\mathrm{iid}}{\sim} + * N_p(\bm{\mu}, \Sigma), \f] + * + * where \f$ (\bm{\mu}, \Sigma) \f$ are stored in a `State::MultiLS` state. + * The sufficient statistics stored are the sum of the \f$ \bm{y}_i \f$'s + * and the sum of \f$ \bm{y}_i^T \bm{y}_i \f$. + */ class MultiNormLikelihood : public BaseLikelihood { diff --git a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h index 96600d4dc..bb9e55687 100644 --- a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h +++ b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h @@ -11,16 +11,18 @@ #include "base_likelihood.h" #include "states/includes.h" -//! A scalar linear regression model -//! -//! Represents the model: -//! y_i ~ N(x_i^T * reg_coeffs, var) -//! where (reg_coeffs, var) are stored in a `State::UniLinRegLS` state -//! -//! The sufficient statistics stored are the -//! 1) sum of y_i^2 -//! 2) sum of x_i^T x_i -//! 3) sum of y_i x_i^T +/** + * A scalar linear regression model, using the `State::UniLinRegLS` state. + * Represents the model: + * + * \f[ + * y_i \mid \bm{x}_i, \bm{\beta}, \sigma^2 + * \stackrel{\small\mathrm{ind}}{\sim} N(\bm{x}_i^T\bm{\beta},\sigma^2), \f] + * + * where \f$ (\bm{\beta}, \sigma^2) \f$ are stored in a `State::UniLinRegLS` + * state. The sufficient statistics stored are the sum of \f$ y_i^2 \f$, the + * sum of \f$ \bm{x}_i^T \bm{x}_i \f$ and the sum of \f$ y_i \bm{x}_i^T \f$. + */ class UniLinRegLikelihood : public BaseLikelihood { @@ -48,13 +50,13 @@ class UniLinRegLikelihood const Eigen::RowVectorXd &covariate, bool add) override; - //! Dimension of the coefficients vector + // Dimension of the coefficients vector unsigned int dim; - //! Represents pieces of y^t y + // Represents pieces of y^t y double data_sum_squares; - //! Represents pieces of X^T X + // Represents pieces of X^T X Eigen::MatrixXd covar_sum_squares; - //! Represents pieces of X^t y + // Represents pieces of X^t y Eigen::VectorXd mixed_prod; }; diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index eb17f6853..e278a3635 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -11,14 +11,18 @@ #include "base_likelihood.h" #include "states/includes.h" -//! A univariate normal likelihood, using the `State::UniLS` state. -//! -//! Represents the model: -//! y_1, ..., y_m ~ N(mu, var) -//! where (mu, var) are stored in a `State::UniLS` state -//! -//! The sufficient statistics stored are the sum of the y_i's -//! and the sum of y_i^2. +/** + * A univariate normal likelihood, using the `State::UniLS` state. Represents + * the model: + * + * \f[ + * y_1, \dots, y_k \mid \mu, \sigma^2 \stackrel{\small\mathrm{iid}}{\sim} + * N(\mu, \sigma^2), \f] + * + * where \f$ (\mu, \sigma^2) \f$ are stored in a `State::UniLS` state. + * The sufficient statistics stored are the sum of the \f$ y_i \f$'s and the + * sum of \f$ y_i^2 \f$. + */ class UniNormLikelihood : public BaseLikelihood { From a33e028ed35cc9dcd01ff230ddccc64f99ffdd88 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:34:44 +0200 Subject: [PATCH 284/317] Fix doxygen warnings --- src/hierarchies/base_hierarchy.h | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 3206c4967..ad71a0c93 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -343,9 +343,10 @@ class BaseHierarchy : public AbstractHierarchy { virtual void initialize_state() = 0; //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf + //! @param hier_params Container of (prior or posterior) hyperparameter + //! values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf virtual double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const { if (!is_conjugate()) { @@ -358,10 +359,11 @@ class BaseHierarchy : public AbstractHierarchy { } //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! @return The evaluation of the lpdf + //! @param hier_params Container of (prior or posterior) hyperparameter + //! values + //! @param datum Point which is to be evaluated + //! @param covariate Covariate vector associated to datum + //! @return The evaluation of the lpdf virtual double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const { From 88534406040831a41e9ec8b5ffbc76aeae61840c Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:35:28 +0200 Subject: [PATCH 285/317] Add latex to doxygen (ONGOING) --- src/hierarchies/abstract_hierarchy.h | 74 ++++++++++++++----------- src/hierarchies/lin_reg_uni_hierarchy.h | 8 +-- src/hierarchies/nnig_hierarchy.h | 7 ++- 3 files changed, 50 insertions(+), 39 deletions(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 5bd4387e0..947f362cf 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -16,38 +16,48 @@ #include "src/hierarchies/updaters/abstract_updater.h" #include "src/utils/rng.h" -//! Abstract base class for a hierarchy object. -//! This class is the basis for a curiously recurring template pattern (CRTP) -//! for `Hierarchy` objects, and is solely composed of interface functions for -//! derived classes to use. For more information about this pattern, as well -//! the list of methods required for classes in this inheritance tree, please -//! refer to the README.md file included in this folder. - -//! This abstract class represents a Bayesian hierarchical model: -//! x_1, ..., x_n \sim f(x | \theta) -//! theta \sim G -//! A Hierarchy object can compute the following quantities: -//! 1- the likelihood log-probability density function -//! 2- the prior predictive probability: \int_\Theta f(x | theta) G(d\theta) -//! (for conjugate models only) -//! 3- the posterior predictive probability -//! \int_\Theta f(x | theta) G(d\theta | x_1, ..., x_n) -//! (for conjugate models only) -//! Moreover, the Hierarchy knows how to sample from the full conditional of -//! theta, possibly in an approximate way. -//! -//! In the context of our Gibbs samplers, an hierarchy represents the parameter -//! value associated to a certain cluster, and also knows which observations -//! are allocated to that cluster. -//! Moreover, hyperparameters and (possibly) hyperpriors associated to them can -//! be shared across multiple Hierarchies objects via a shared pointer. -//! In conjunction with a single `Mixing` object, a collection of `Hierarchy` -//! objects completely defines a mixture model, and these two parts can be -//! chosen independently of each other. -//! Communication with other classes, as well as storage of some relevant -//! values, is performed via appropriately defined Protobuf messages (see for -//! instance the proto/ls_state.proto and proto/hierarchy_prior.proto files) -//! and their relative class methods. +/** + * Abstract base class for a hierarchy object. + * This class is the basis for a curiously recurring template pattern (CRTP) + * for `Hierarchy` objects, and is solely composed of interface functions for + * derived classes to use. For more information about this pattern, as well + * the list of methods required for classes in this inheritance tree, please + * refer to the README.md file included in this folder. + * + * This abstract class represents a Bayesian hierarchical model: + * + * \f[ + * x_1,\dots,x_n &\sim f(x \mid \theta) \\ + * \theta &\sim G + * \f] + * + * A Hierarchy object can compute the following quantities: + * + * 1. the likelihood log-probability density function + * 2. the prior predictive probability: \f$ \int_\Theta f(x \mid \theta) + * G(d\theta) \f$ (for conjugate models only) + * 3. the posterior predictive probability + * \f$ \int_\Theta f(x \mid \theta) G(d\theta \mid x_1, ..., x_n) \f$ + * (for conjugate models only) + * + * Moreover, the Hierarchy knows how to sample from the full conditional of + * \f$ \theta \f$, possibly in an approximate way. + * + * In the context of our Gibbs samplers, an hierarchy represents the parameter + * value associated to a certain cluster, and also knows which observations + * are allocated to that cluster. + * + * Moreover, hyperparameters and (possibly) hyperpriors associated to them can + * be shared across multiple Hierarchies objects via a shared pointer. + * In conjunction with a single `Mixing` object, a collection of `Hierarchy` + * objects completely defines a mixture model, and these two parts can be + * chosen independently of each other. + * + * Communication with other classes, as well as storage of some relevant + * values, is performed via appropriately defined Protobuf messages (see for + * instance the `proto/ls_state.proto` and `proto/hierarchy_prior.proto` files) + * and their relative class methods. + */ class AbstractHierarchy { public: diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index 50d4239c2..bd208ef01 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -20,10 +20,10 @@ * \f] * * The state consists of the `regression_coeffs` \f$ \beta \f$, and the `var` - * \f$ \sigma^2 \f$. Lambda is called the variance-scaling factor. Note that - * this hierarchy is conjugate, thus the marginal distribution is available in - * closed form. For more information, please refer to the parent class - * `BaseHierarchy`, to the class `UniLinRegLikelihood` for details on the + * \f$ \sigma^2 \f$. \f$ \Lambda \f$ is called the variance-scaling factor. + * Note that this hierarchy is conjugate, thus the marginal distribution is + * available in closed form. For more information, please refer to the parent + * class `BaseHierarchy`, to the class `UniLinRegLikelihood` for details on the * likelihood model and to `MNIGPriorModel` for details on the prior model. */ diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index 36d14c94e..4d1db6200 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -45,9 +45,10 @@ class NNIGHierarchy }; //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf + //! @param hier_params Container of (prior or posterior) hyperparameter + //! values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { auto params = hier_params->nnig_state(); From 0bb77d24204f1ee817d20d03d91f38209efaa60d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:36:36 +0200 Subject: [PATCH 286/317] Improved rst files (ONGOING) --- docs/Doxyfile.in | 4 +-- docs/hierarchies.rst | 79 ++++++++++++++++++++++++---------------- docs/index.rst | 11 ++++-- docs/likelihoods.rst | 85 ++++++++++++++++++++++++++++++++++++++++++++ docs/utils.rst | 13 ++++--- 5 files changed, 152 insertions(+), 40 deletions(-) create mode 100644 docs/likelihoods.rst diff --git a/docs/Doxyfile.in b/docs/Doxyfile.in index 30cda29e9..6b163fa99 100644 --- a/docs/Doxyfile.in +++ b/docs/Doxyfile.in @@ -1819,7 +1819,7 @@ PAPER_TYPE = a4 # If left blank no extra packages will be included. # This tag requires that the tag GENERATE_LATEX is set to YES. -EXTRA_PACKAGES = +EXTRA_PACKAGES = bm # The LATEX_HEADER tag can be used to specify a personal LaTeX header for the # generated LaTeX document. The header should contain everything until the first @@ -1893,7 +1893,7 @@ USE_PDFLATEX = YES # The default value is: NO. # This tag requires that the tag GENERATE_LATEX is set to YES. -LATEX_BATCHMODE = NO +LATEX_BATCHMODE = YES # If the LATEX_HIDE_INDICES tag is set to YES then doxygen will not include the # index chapters (such as File Index, Compound Index, etc.) in the output. diff --git a/docs/hierarchies.rst b/docs/hierarchies.rst index 21c6fe1d7..7f520588a 100644 --- a/docs/hierarchies.rst +++ b/docs/hierarchies.rst @@ -6,24 +6,23 @@ Hierarchies In our algorithms, we store a vector of hierarchies, each of which represent a parameter :math:`\theta_h`. The hierarchy implements all the methods needed to update :math:`\theta_h`: sampling from the prior distribution :math:`P_0`, the full-conditional distribution (given the data {:math:`y_i` such that :math:`c_i = h`} ) and so on. - ------------------------- Main operations performed ------------------------- A hierarchy must be able to perform the following operations: -1. Sample from the prior distribution: generate :math:`\theta_h \sim P_0` [``sample_prior``] -2. Sample from the 'full conditional' distribution: generate theta_h from the distribution :math:`p(\theta_h \mid \cdots ) \propto P_0(\theta_h) \prod_{i: c_i = h} k(y_i | \theta_h)` [``sample_full_conditional``] -3. Update the hyperparameters involved in :math:`P_0` [``update_hypers``] -4. Evaluate the likelihood in one point, i.e. :math:`k(x | \theta_h)` for theta_h the current value of the parameters [``like_lpdf``] -5. When :math:`k` and :math:`P_0` are conjugate, we must also be able to compute the marginal/prior predictive distribution in one point, i.e. :math:`m(x) = \int k(x | \theta) P_0(d\theta)`, and the conditional predictive distribution :math:`m(x | \textbf{y} ) = \int k(x | \theta) P_0(d\theta | \{y_i: c_i = h\})` [``prior_pred_lpdf``, ``conditional_pred_lpdf``] +a. Sample from the prior distribution: generate :math:`\theta_h \sim P_0` [``sample_prior``] +b. Sample from the 'full conditional' distribution: generate theta_h from the distribution :math:`p(\theta_h \mid \cdots ) \propto P_0(\theta_h) \prod_{i: c_i = h} k(y_i | \theta_h)` [``sample_full_conditional``] +c. Update the hyperparameters involved in :math:`P_0` [``update_hypers``] +d. Evaluate the likelihood in one point, i.e. :math:`k(x | \theta_h)` for theta_h the current value of the parameters [``like_lpdf``] +e. When :math:`k` and :math:`P_0` are conjugate, we must also be able to compute the marginal/prior predictive distribution in one point, i.e. :math:`m(x) = \int k(x | \theta) P_0(d\theta)`, and the conditional predictive distribution :math:`m(x | \textbf{y} ) = \int k(x | \theta) P_0(d\theta | \{y_i: c_i = h\})` [``prior_pred_lpdf``, ``conditional_pred_lpdf``] Moreover, the following utilities are needed: -6. write the current state :math:`\theta_h` into a appropriately defined Protobuf message [``write_state_to_proto``] -7. restore theta_h from a given Protobuf message [``set_state_from_proto``] -8. write the values of the hyperparameters in :math:`P_0` to a Protobuf message [``write_hypers_to_proto``] +f. write the current state :math:`\theta_h` into a appropriately defined Protobuf message [``write_state_to_proto``] +g. restore theta_h from a given Protobuf message [``set_state_from_proto``] +h. write the values of the hyperparameters in :math:`P_0` to a Protobuf message [``write_hypers_to_proto``] In each hierarchy, we also keep track of which data points are allocated to the hierarchy. @@ -42,21 +41,26 @@ The code thus composes of: a virtual class defining the API, a template base cla The class ``AbstractHierarchy`` defines the API, i.e. all the methods that need to be called from outside of a ``Hierarchy`` class. A template class ``BaseHierarchy`` inherits from ``AbstractHierarchy`` and implements some of the necessary virtual methods, which need not be implemented by the child classes. -Instead, child classes must implement: +.. toctree:: + :maxdepth: 1 + :caption: API: hierarchies submodules -1. ``like_lpdf``: evaluates :math:`k(x | \theta_h)` -2. ``marg_lpdf``: evaluates m(x) given some parameters :math:`\theta_h` (could be both the hyperparameters in :math:`P_0` or the paramters given by the full conditionals) -3. ``draw``: samples from :math:`P_0` given the parameters -4. ``clear_summary_statistics``: clears all the summary statistics -5. ``update_hypers``: performs the update of parameters in :math:`P_0` given all the :math:`\theta_h` (passed as a vector of protobuf Messages) -6. ``initialize_state``: initializes the current :math:`\theta_h` given the hyperparameters in :math:`P_0` -7. ``initialize_hypers``: initializes the hyperparameters in :math:`P_0` given their hyperprior -8. ``update_summary_statistics``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy -9. ``get_posterior_parameters``: returns the paramters of the full conditional distribution **possible only when** :math:`P_0` **and** :math:`k` **are conjugate** -10. ``set_state_from_proto`` -11. ``write_state_to_proto`` -12. ``write_hypers_to_proto`` + likelihoods +Instead, child classes must implement: + +a. ``like_lpdf``: evaluates :math:`k(x | \theta_h)` +b. ``marg_lpdf``: evaluates m(x) given some parameters :math:`\theta_h` (could be both the hyperparameters in :math:`P_0` or the paramters given by the full conditionals) +c. ``draw``: samples from :math:`P_0` given the parameters +d. ``clear_summary_statistics``: clears all the summary statistics +e. ``update_hypers``: performs the update of parameters in :math:`P_0` given all the :math:`\theta_h` (passed as a vector of protobuf Messages) +f. ``initialize_state``: initializes the current :math:`\theta_h` given the hyperparameters in :math:`P_0` +g. ``initialize_hypers``: initializes the hyperparameters in :math:`P_0` given their hyperprior +h. ``update_summary_statistics``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy +i. ``get_posterior_parameters``: returns the paramters of the full conditional distribution **possible only when** :math:`P_0` **and** :math:`k` **are conjugate** +j. ``set_state_from_proto`` +k. ``write_state_to_proto`` +l. ``write_hypers_to_proto`` Note that not all of these members are declared virtual in ``AbstractHierarchy`` or ``BaseHierarchy``: this is because virtual members are only the ones that must be called from outside the ``Hierarchy``, the other ones are handled via CRTP. Not having them virtual saves a lot of lookups in the vtables. The ``BaseHierarchy`` class takes 4 template parameters: @@ -66,12 +70,11 @@ The ``BaseHierarchy`` class takes 4 template parameters: 3. ``Hyperparams`` is usually a struct representing the parameters in :math:`P_0` 4. ``Prior`` must be a protobuf object encoding the prior parameters. +.. Finally, a ``ConjugateHierarchy`` takes care of the implementation of some methods that are specific to conjugate models. -Finally, a ``ConjugateHierarchy`` takes care of the implementation of some methods that are specific to conjugate models. - -------- -Classes -------- +---------------- +Abstract Classes +---------------- .. doxygenclass:: AbstractHierarchy :project: bayesmix @@ -79,9 +82,11 @@ Classes .. doxygenclass:: BaseHierarchy :project: bayesmix :members: -.. doxygenclass:: ConjugateHierarchy - :project: bayesmix - :members: + +--------------------------------- +Classes for Conjugate Hierarchies +--------------------------------- + .. doxygenclass:: NNIGHierarchy :project: bayesmix :members: @@ -91,3 +96,17 @@ Classes .. doxygenclass:: LinRegUniHierarchy :project: bayesmix :members: + +------------------------------------- +Classes for Non-Conjugate Hierarchies +------------------------------------- + +.. doxygenclass:: NNxIGHierarchy + :project: bayesmix + :members: +.. doxygenclass:: LapNIGHierarchy + :project: bayesmix + :members: +.. doxygenclass:: FAHierarchy + :project: bayesmix + :members: diff --git a/docs/index.rst b/docs/index.rst index 59d19db44..271387414 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,7 +32,8 @@ There are currently three submodules to the ``bayesmix`` library, represented by Further, we employ Protocol buffers for several purposes, including serialization. The list of all protos with their docs is available in the ``protos`` link below. .. toctree:: - :maxdepth: 1 + :maxdepth: 2 + :titlesonly: :caption: API: library submodules algorithms @@ -43,11 +44,15 @@ Further, we employ Protocol buffers for several purposes, including serializatio utils - Tutorials ========= -:doc:`tutorial` +.. toctree:: + :maxdepth: 1 + + tutorial + +.. :doc:`tutorial` Python interface diff --git a/docs/likelihoods.rst b/docs/likelihoods.rst new file mode 100644 index 000000000..78ce5294d --- /dev/null +++ b/docs/likelihoods.rst @@ -0,0 +1,85 @@ +bayesmix/hierarchies/likelihoods + +Likelihoods +=========== + +The ``Likelihood`` sub-module represents the likelihood we have assumed for the data in a given cluster. Each ``Likelihood`` class represents the sampling model + +.. math:: + y_1, \ldots, y_k \mid \bm{\tau} \stackrel{\small\mathrm{iid}}{\sim} f(\cdot \mid \bm{\tau}) + +for a specific choice of the probability density function :math:`f`. + +------------------------- +Main operations performed +------------------------- + +A likelihood object must be able to perform the following operations: + +a. First of all, we require the \code{lpdf()} and \code{lpdf\_grid()} methods, which simply evaluate the loglikelihood in a given point or in a grid of points (also in case of a \emph{dependent} likelihood, i.e., with covariates associated to each observation) [``lpdf()`` and ``lpdf_grid``] +b. In case you want to rely on a Metropolis-like updater, the likelihood needs to evaluation of the likelihood of the whole cluster starting from the vector of unconstrained parameters [``cluster_lpdf_from_unconstrained()``]. Observe that the ``AbstractLikelihood`` class provides two such methods, one returning a ``double`` and one returning a ``stan::math::var``. The latter is used to automatically compute the gradient of the likelihood via Stan's automatic differentiation, if needed. In practice, users do not need to implement both methods separately and can implement only one templated method +c. manage the insertion and deletion of a datum in the cluster [``add_datum`` and ``remove_datum``] +d. update the summary statistics associated to the likelihood [``update_summary_statistics``]. Summary statistics (when available) are used to evaluate the likelihood function on the whole cluster, as well as to perform the posterior updates of :math:`\bm{\tau}`. This usually gives a substantial speedup + +-------------- +Code structure +-------------- + +In principle, the ``Likelihood`` classes are responsible only of evaluating the log-likelihood function given a specific choice of parameters :math:`\bm{\tau}`. +Therefore, a simple inheritance structure would seem appropriate. However, the nature of the parameters :math:`\bm{\tau}` can be very different across different models (think for instance of the difference between the univariate normal and the multivariate normal paramters). As such, we employ CRTP to manage the polymorphic nature of ``Likelihood`` classes. + +The class ``AbstractLikelihood`` defines the API, i.e. all the methods that need to be called from outside of a ``Likelihood`` class. +A template class ``BaseLikelihood`` inherits from ``AbstractLikelihood`` and implements some of the necessary virtual methods, which need not be implemented by the child classes. + +Instead, child classes **must** implement: + +a. ``compute_lpdf``: evaluates :math:`k(x \mid \theta_h)` +b. ``update_sum_stats``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy +c. ``clear_summary_statistics``: clears all the summary statistics +d. ``is_dependent``: defines if the given likelihood depends on covariates +e. ``is_multivariate``: defines if the given likelihood is for multivariate data + +In case the likelihood needs to be used in a Metropolis-like updater, child classes **should** also implement: + +f. ``cluster_lpdf_from_unconstrained``: evaluates :math:`\prod_{i: c_i = h} k(x_i \mid \tilde{\theta}_h)`, where :math:`\tilde{\theta}_h` is the vector of unconstrained parameters. + +---------------- +Abstract Classes +---------------- + +.. doxygenclass:: AbstractLikelihood + :project: bayesmix + :members: +.. doxygenclass:: BaseLikelihood + :project: bayesmix + :members: + +---------------------------------- +Classes for Univariate Likelihoods +---------------------------------- + +.. doxygenclass:: UniNormLikelihood + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: UniLinRegLikelihood + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: LaplaceLikelihood + :project: bayesmix + :members: + :protected-members: + +------------------------------------ +Classes for Multivariate Likelihoods +------------------------------------ + +.. doxygenclass:: MultiNormLikelihood + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: FALikelihood + :project: bayesmix + :members: + :protected-members: diff --git a/docs/utils.rst b/docs/utils.rst index 89c7e4533..45420f464 100644 --- a/docs/utils.rst +++ b/docs/utils.rst @@ -18,17 +18,20 @@ Distribution-related utilities :project: bayesmix ---------------------------------------------- -``Eigen`` input-output and matrix manipulation +``Eigen`` matrix manipulation utilities ---------------------------------------------- .. doxygenfile:: eigen_utils.h :project: bayesmix - :members: + +-------------------------------- +``Eigen`` input-output utilities +-------------------------------- .. doxygenfile:: io_utils.h :project: bayesmix -------------------------- -``protobuf`` input-output -------------------------- +----------------------------------- +``protobuf`` input-output utilities +----------------------------------- .. doxygenfile:: proto_utils.h :project: bayesmix From b7f3658b1abe0c8b6e81232ae62c46072dc45903 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:37:10 +0200 Subject: [PATCH 287/317] Switch-off docker (TO RESTORE) --- docs/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/CMakeLists.txt b/docs/CMakeLists.txt index 5981d2b7a..3318965a8 100644 --- a/docs/CMakeLists.txt +++ b/docs/CMakeLists.txt @@ -82,6 +82,6 @@ install(DIRECTORY ${SPHINX_BUILD} DESTINATION ${CMAKE_INSTALL_DOCDIR}) add_custom_target(document_bayesmix) -add_dependencies(document_bayesmix document_protos) +# add_dependencies(document_bayesmix document_protos) add_dependencies(document_bayesmix Doxygen) add_dependencies(document_bayesmix Sphinx) From a608295a1eeee416ebbaef87a33a93817548fee7 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:37:22 +0200 Subject: [PATCH 288/317] Comment line --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 3501007dd..af16e0bdc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -57,7 +57,7 @@ def configureDoxyfile(input_dir, output_dir): html_theme = 'haiku' -html_static_path = ['_static'] +# html_static_path = ['_static'] highlight_language = 'cpp' From e3c9239a5066ca16fd609263a30e71d0a3bb02d9 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 16 May 2022 16:53:58 +0200 Subject: [PATCH 289/317] addedd missing classes --- docs/algorithms.rst | 3 + docs/hierarchies.rst | 22 +- docs/likelihoods.rst | 2 +- docs/mixings.rst | 3 + docs/protos.html | 6095 ++++++++++++++++++++++-------------------- 5 files changed, 3230 insertions(+), 2895 deletions(-) diff --git a/docs/algorithms.rst b/docs/algorithms.rst index c61efc257..635800167 100644 --- a/docs/algorithms.rst +++ b/docs/algorithms.rst @@ -20,6 +20,9 @@ Algorithms .. doxygenclass:: Neal8Algorithm :project: bayesmix :members: +.. doxygenclass:: SplitMergeAlgorithm + :project: bayesmix + :members: .. doxygenclass:: ConditionalAlgorithm :project: bayesmix :members: diff --git a/docs/hierarchies.rst b/docs/hierarchies.rst index 7f520588a..f5272c1fd 100644 --- a/docs/hierarchies.rst +++ b/docs/hierarchies.rst @@ -6,6 +6,21 @@ Hierarchies In our algorithms, we store a vector of hierarchies, each of which represent a parameter :math:`\theta_h`. The hierarchy implements all the methods needed to update :math:`\theta_h`: sampling from the prior distribution :math:`P_0`, the full-conditional distribution (given the data {:math:`y_i` such that :math:`c_i = h`} ) and so on. +In BayesMix, each choice of :math:`G_0` is implemented in a different ``PriorModel`` object and each choice of :math:k(\cdot \mid \cdot)` in a ``Likelihood`` object, so that it is straightforward to create a new ``Hierarchy`` using one of the already implemented priors or likelihoods. +The sampling from the full conditional of :math:`\theta_h` is performed in an ``Updater`` class. +`State` classes are used to store parameters ``\theta_h`s of every mixture component. +Their main purpose is to handle serialization and de-serialization of the state + +.. toctree:: + :maxdepth: 1 + :caption: API: hierarchies submodules + + likelihoods + prior_models + updaters + states + + ------------------------- Main operations performed ------------------------- @@ -41,12 +56,6 @@ The code thus composes of: a virtual class defining the API, a template base cla The class ``AbstractHierarchy`` defines the API, i.e. all the methods that need to be called from outside of a ``Hierarchy`` class. A template class ``BaseHierarchy`` inherits from ``AbstractHierarchy`` and implements some of the necessary virtual methods, which need not be implemented by the child classes. -.. toctree:: - :maxdepth: 1 - :caption: API: hierarchies submodules - - likelihoods - Instead, child classes must implement: a. ``like_lpdf``: evaluates :math:`k(x | \theta_h)` @@ -72,6 +81,7 @@ The ``BaseHierarchy`` class takes 4 template parameters: .. Finally, a ``ConjugateHierarchy`` takes care of the implementation of some methods that are specific to conjugate models. + ---------------- Abstract Classes ---------------- diff --git a/docs/likelihoods.rst b/docs/likelihoods.rst index 78ce5294d..72f2d73cd 100644 --- a/docs/likelihoods.rst +++ b/docs/likelihoods.rst @@ -16,7 +16,7 @@ Main operations performed A likelihood object must be able to perform the following operations: -a. First of all, we require the \code{lpdf()} and \code{lpdf\_grid()} methods, which simply evaluate the loglikelihood in a given point or in a grid of points (also in case of a \emph{dependent} likelihood, i.e., with covariates associated to each observation) [``lpdf()`` and ``lpdf_grid``] +a. First of all, we require the ``lpdf()`` and ``lpdf\_grid()`` methods, which simply evaluate the loglikelihood in a given point or in a grid of points (also in case of a \emph{dependent} likelihood, i.e., with covariates associated to each observation) [``lpdf()`` and ``lpdf_grid``] b. In case you want to rely on a Metropolis-like updater, the likelihood needs to evaluation of the likelihood of the whole cluster starting from the vector of unconstrained parameters [``cluster_lpdf_from_unconstrained()``]. Observe that the ``AbstractLikelihood`` class provides two such methods, one returning a ``double`` and one returning a ``stan::math::var``. The latter is used to automatically compute the gradient of the likelihood via Stan's automatic differentiation, if needed. In practice, users do not need to implement both methods separately and can implement only one templated method c. manage the insertion and deletion of a datum in the cluster [``add_datum`` and ``remove_datum``] d. update the summary statistics associated to the likelihood [``update_summary_statistics``]. Summary statistics (when available) are used to evaluate the likelihood function on the whole cluster, as well as to perform the posterior updates of :math:`\bm{\tau}`. This usually gives a substantial speedup diff --git a/docs/mixings.rst b/docs/mixings.rst index 67d2b1074..64f3b593f 100644 --- a/docs/mixings.rst +++ b/docs/mixings.rst @@ -34,6 +34,9 @@ Classes .. doxygenclass:: PitYorMixing :project: bayesmix :members: +.. doxygenclass:: MixtureFiniteMixing + :project: bayesmix + :members: .. doxygenclass:: TruncatedSBMixing :project: bayesmix :members: diff --git a/docs/protos.html b/docs/protos.html index abc4c641b..26b9929a3 100644 --- a/docs/protos.html +++ b/docs/protos.html @@ -3,8 +3,12 @@ Protocol Documentation - - + + - - + -

Protocol Documentation

Table of Contents

- - -
-

algorithm_id.proto

Top -
-

- - - - -

AlgorithmId

-

Enum for the different types of algorithms.

References

[1] R. M. Neal, Markov Chain Sampling Methods for Dirichlet Process Mixture Models. JCGS(2000)

[2] H. Ishwaran and L. F. James, Gibbs Sampling Methods for Stick-Breaking Priors. JASA(2001)

[3] S. Jain and R. M. Neal, A Split-Merge Markov Chain Monte Carlo Procedure for the Dirichlet Process Mixture Model. JCGS (2004)

[4] M. Kalli, J. Griffin and S. G. Walker, Slice sampling mixture models. Stat and Comp. (2011)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
NameNumberDescription
UNKNOWN_ALGORITHM0

Neal21

Neal's Algorithm 2, see [1]

Neal32

Neal's Algorithm 3, see [1]

Neal83

Neal's Algorithm 8, see [1]

BlockedGibbs4

Ishwaran and James Blocked Gibbs, see [2]

SplitMerge5

Jain and Neal's Split&Merge, see [3]. NOT IMPLEMENTED YET!

Slice6

Slice sampling, see [4]. NOT IMPLEMENTED YET!

- - - - - - - -
-

algorithm_params.proto

Top -
-

- - -

AlgorithmParams

-

Parameters used in the BaseAlgorithm class and childs.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
algo_idstring

Id of the Algorithm. Must match the ones in the AlgorithmId enum

rng_seeduint32

Seed for the random number generator

iterationsuint32

Total number of iterations of the MCMC chain

burninuint32

Number of iterations to discard as burn-in

init_num_clustersuint32

Number of clusters to initialize the algorithm. It may be overridden by conditional mixings for which the number of components is fixed (e.g. TruncatedSBMixing). In this case, this value is ignored.

neal8_n_auxuint32

Number of auxiliary unique values for the Neal8 algorithm

splitmerge_n_restr_gs_updatesuint32

Number of restricted GS scans for each MH step.

splitmerge_n_mh_updatesuint32

Number of MH updates for each iteration of Split and Merge algorithm.

splitmerge_n_full_gs_updatesuint32

Number of full GS scans for each iteration of Split and Merge algorithm.

- - - - - - - - - - - - - -
-

algorithm_state.proto

Top -
-

- - -

AlgorithmState

-

This message represents the state of a Gibbs sampler for

a mixture model. All algorithms must be able to handle this

message, by filling it with the current state of the sampler

in the `get_state_as_proto` method.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
cluster_statesAlgorithmState.ClusterStaterepeated

The state of each cluster

cluster_allocsint32repeated

Vector of allocations into clusters, one for each observation

mixing_stateMixingState

The state of the `Mixing`

iteration_numint32

The iteration number

hierarchy_hypersAlgorithmState.HierarchyHypers

The current values of the hyperparameters of the hierarchy

- - - - - -

AlgorithmState.ClusterState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
uni_ls_stateUniLSState

State of a univariate location-scale family

multi_ls_stateMultiLSState

State of a multivariate location-scale family

lin_reg_uni_ls_stateLinRegUniLSState

State of a linear regression univariate location-scale family

general_stateVector

Just a vector of doubles

fa_stateFAState

State of a Mixture of Factor Analysers

cardinalityint32

How many observations are in this cluster

- - - - - -

AlgorithmState.HierarchyHypers

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fake_priorEmptyPrior

nnig_stateNIGDistribution

nnw_stateNWDistribution

lin_reg_uni_stateMultiNormalIGDistribution

lapnig_stateLapNIGState

fa_stateFAPriorDistribution

- - - - - - - - - - - - - -
-

distribution.proto

Top -
-

- - -

BetaDistribution

-

Parameters defining a beta distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
shape_adouble

shape_bdouble

- - - - - -

GammaDistribution

-

Parameters defining a gamma distribution with density

f(x) = x^(shape-1) * exp(-rate * x) / Gamma(shape)

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
shapedouble

ratedouble

- - - - - -

InvWishartDistribution

-

Parameters defining an Inverse Wishart distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
deg_freedouble

scaleMatrix

- - - - - -

MultiNormalDistribution

-

Parameters defining a multivariate normal distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

varMatrix

- - - - - -

MultiNormalIGDistribution

-

Parameters for the Normal Inverse Gamma distribution commonly employed in

linear regression models, with density

f(beta, var) = N(beta | mean, var * var_scaling^{-1}) * IG(var | shape, scale)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

var_scalingMatrix

shapedouble

scaledouble

- - - - - -

NIGDistribution

-

Parameters of a Normal Inverse-Gamma distribution

with density

f(x, y) = N(x | mu, y/var_scaling) * IG(y | shape, scale)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

var_scalingdouble

shapedouble

scaledouble

- - - - - -

NWDistribution

-

Parameters of a Normal Wishart distribution

with density

f(x, y) = N(x | mu, (y * var_scaling)^{-1}) * IW(y | deg_free, scale)

where x is a vector and y is a matrix (spd)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

var_scalingdouble

deg_freedouble

scaleMatrix

- - - - - -

UniNormalDistribution

-

Parameters defining a univariate normal distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

vardouble

- - - - - - - - - - - - - -
-

hierarchy_id.proto

Top -
-

- - - - -

HierarchyId

-

Enum for the different types of Hierarchy.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
NameNumberDescription
UNKNOWN_HIERARCHY0

NNIG1

Normal - Normal Inverse Gamma

NNW2

Normal - Normal Wishart

LinRegUni3

Linear Regression (univariate response)

LapNIG4

Laplace - Normal Inverse Gamma

FA5

Factor Analysers

- - - - - - - -
-

hierarchy_prior.proto

Top -
-

- - -

EmptyPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fake_fielddouble

- - - - - -

FAPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesFAPriorDistribution

- - - - - -

FAPriorDistribution

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mutildeVector

betaVector

phidouble

alpha0double

quint32

- - - - - -

LapNIGPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesLapNIGState

- - - - - -

LapNIGState

-

Prior for the parameters of the base measure in a Laplace - Normal Inverse Gamma hierarchy

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

vardouble

shapedouble

scaledouble

mh_mean_vardouble

mh_log_scale_vardouble

- - - - - -

LinRegUniPrior

-

Prior for the parameters of the base measure in a Normal mixture model with a covariate-dependent

location.

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesMultiNormalIGDistribution

- - - - - -

NNIGPrior

-

Prior for the parameters of the base measure in a Normal-Normal Inverse Gamma hierarchy

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesNIGDistribution

no prior, just fixed values

normal_mean_priorNNIGPrior.NormalMeanPrior

prior on the mean

ngg_priorNNIGPrior.NGGPrior

prior on the mean, var_scaling, and scale

- - - - - -

NNIGPrior.NGGPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorUniNormalDistribution

var_scaling_priorGammaDistribution

shapedouble

scale_priorGammaDistribution

- - - - - -

NNIGPrior.NormalMeanPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorUniNormalDistribution

var_scalingdouble

shapedouble

scaledouble

- - - - - -

NNWPrior

-

Prior for the parameters of the base measure in a Normal-Normal Wishart hierarchy

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesNWDistribution

no prior, just fixed values

normal_mean_priorNNWPrior.NormalMeanPrior

prior on the mean

ngiw_priorNNWPrior.NGIWPrior

prior on the mean, var_scaling, and scale

- - - - - -

NNWPrior.NGIWPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorMultiNormalDistribution

var_scaling_priorGammaDistribution

deg_freedouble

scale_priorInvWishartDistribution

- - - - - -

NNWPrior.NormalMeanPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorMultiNormalDistribution

var_scalingdouble

deg_freedouble

scaleMatrix

- - - - - - - - - - - - - -
-

ls_state.proto

Top -
-

- - -

FAState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
muVector

psiVector

etaMatrix

lambdaMatrix

- - - - - -

LinRegUniLSState

-

Parameters of a univariate linear regression

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
regression_coeffsVector

regression coefficients

vardouble

variance of the noise

- - - - - -

MultiLSState

-

Parameters of a multivariate location-scale family of distributions,

parameterized by mean and precision (inverse of variance). For

convenience, we also store the Cholesky factor of the precision matrix.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

precMatrix

prec_cholMatrix

- - - - - -

UniLSState

-

Parameters of a univariate location-scale family of distributions.

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

vardouble

- - - - - - - - - - - - - -
-

matrix.proto

Top -
-

- - -

Matrix

-

Message representing a matrix of doubles.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
rowsint32

number of rows

colsint32

number of columns

datadoublerepeated

matrix elements

rowmajorbool

if true, the data is read in row-major order

- - - - - -

Vector

-

Message representing a vector of doubles.

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
sizeint32

number of elements in the vector

datadoublerepeated

vector elements

- - - - - - - - - - - - - -
-

mixing_id.proto

Top -
-

- - - - -

MixingId

-

Enum for the different types of Mixing.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
NameNumberDescription
UNKNOWN_MIXING0

DP1

Dirichlet Process

PY2

Pitman-Yor Process

LogSB3

Logit Stick-Breaking Process

TruncSB4

Truncated Stick-Breaking Process

MFM5

Mixture of finite mixtures

- - - - - - - -
-

mixing_prior.proto

Top -
-

- - -

DPPrior

-

Prior for the concentration parameter of a Dirichlet process

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valueDPState

No prior, just a fixed value

gamma_priorDPPrior.GammaPrior

Gamma prior on the total mass

- - - - - -

DPPrior.GammaPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmass_priorGammaDistribution

- - - - - -

LogSBPrior

-

Definition of the parameters of a Logit-Stick Breaking process.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
normal_priorMultiNormalDistribution

Normal prior on the regression coefficients

step_sizedouble

Steps size for the MALA algorithm used for posterior inference (TODO: move?)

num_componentsuint32

Number of components in the process

- - - - - -

MFMPrior

-

Prior for the Poisson rate and Dirichlet parameters of a MFM (Finite Dirichlet) process.

For the moment, we only support fixed values

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valueMFMState

No prior, just a fixed value

- - - - - -

PYPrior

-

Prior for the strength and discount parameters of a Pitman-Yor process.

For the moment, we only support fixed values

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesPYState

- - - - - -

TruncSBPrior

-

Definition of the parameters of a truncated Stick-Breaking process

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
beta_priorsTruncSBPrior.BetaPriors

General stick-breaking distributions

dp_priorTruncSBPrior.DPPrior

Truncated Dirichlet process

py_priorTruncSBPrior.PYPrior

Truncated Pitman-Yor process

mfm_priorTruncSBPrior.MFMPrior

num_componentsuint32

Number of components in the process

- - - - - -

TruncSBPrior.BetaPriors

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
beta_distributionsBetaDistributionrepeated

General stick-breaking distributions

- - - - - -

TruncSBPrior.DPPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

- - - - - -

TruncSBPrior.MFMPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

- - - - - -

TruncSBPrior.PYPrior

-

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
strengthdouble

Truncated Pitman-Yor process

discountdouble

- - - - - - - - - - - - - -
-

mixing_state.proto

Top -
-

- - -

DPState

-

State of a Dirichlet process

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmassdouble

the total mass of the DP

- - - - - -

LogSBState

-

State of a Logit-Stick Breaking process

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
regression_coeffsMatrix

Num_Components x Num_Features matrix. Each row is the regression coefficients for a component.

- - - - - -

MFMState

-

State of a MFM (Finite Dirichlet) process

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
lambdadouble

rate parameter of Poisson prior on number of compunents of the MFM

gammadouble

parameter of the dirichlet distribution for the mixing weights

- - - - - -

MixingState

-

Wrapper of all possible mixing states into a single oneof

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
dp_stateDPState

py_statePYState

log_sb_stateLogSBState

trunc_sb_stateTruncSBState

mfm_stateMFMState

- - - - - -

PYState

-

State of a Pitman-Yor process

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
strengthdouble

discountdouble

- - - - - -

TruncSBState

-

State of a truncated sitck breaking process. For convenice we store also the logarithm of the weights

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
sticksVector

logweightsVector

- - - - - - - - - - - - - -
-

semihdp.proto

Top -
-

- - -

SemiHdpParams

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
pseudo_priorSemiHdpParams.PseudoPriorParams

dirichlet_concentrationdouble

rest_allocs_updatestring

Either "full", "metro_base", "metro_dist"

totalmass_restdouble

totalmass_hdpdouble

w_priorSemiHdpParams.WPriorParams

- - - - - -

SemiHdpParams.PseudoPriorParams

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
card_weightdouble

mean_perturb_sddouble

var_perturb_fracdouble

- - - - - -

SemiHdpParams.WPriorParams

-

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
shape1double

shape2double

- - - - - -

SemiHdpState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
restaurantsSemiHdpState.RestaurantStaterepeated

groupsSemiHdpState.GroupStaterepeated

tausSemiHdpState.ClusterStaterepeated

cint32repeated

wdouble

- - - - - -

SemiHdpState.ClusterState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
uni_ls_stateUniLSState

multi_ls_stateMultiLSState

lin_reg_uni_ls_stateLinRegUniLSState

cardinalityint32

- - - - - -

SemiHdpState.GroupState

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
cluster_allocsint32repeated

- - - - - -

SemiHdpState.RestaurantState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
theta_starsSemiHdpState.ClusterStaterepeated

n_by_clusint32repeated

table_to_sharedint32repeated

table_to_idioint32repeated

- - - - - - - - - - - - +
+

algorithm_id.proto

+ Top +
+

+ +

AlgorithmId

+

Enum for the different types of algorithms.

+

References

+

+ [1] R. M. Neal, Markov Chain Sampling Methods for Dirichlet Process + Mixture Models. JCGS(2000) +

+

+ [2] H. Ishwaran and L. F. James, Gibbs Sampling Methods for Stick-Breaking + Priors. JASA(2001) +

+

+ [3] S. Jain and R. M. Neal, A Split-Merge Markov Chain Monte Carlo + Procedure for the Dirichlet Process Mixture Model. JCGS (2004) +

+

+ [4] M. Kalli, J. Griffin and S. G. Walker, Slice sampling mixture models. + Stat and Comp. (2011) +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameNumberDescription
UNKNOWN_ALGORITHM0

Neal21

Neal's Algorithm 2, see [1]

Neal32

Neal's Algorithm 3, see [1]

Neal83

Neal's Algorithm 8, see [1]

BlockedGibbs4

Ishwaran and James Blocked Gibbs, see [2]

SplitMerge5 +

+ Jain and Neal's Split&Merge, see [3]. NOT IMPLEMENTED YET! +

+
Slice6

Slice sampling, see [4]. NOT IMPLEMENTED YET!

+ +
+

algorithm_params.proto

+ Top +
+

+ +

AlgorithmParams

+

Parameters used in the BaseAlgorithm class and childs.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
algo_idstring +

+ Id of the Algorithm. Must match the ones in the AlgorithmId enum +

+
rng_seeduint32

Seed for the random number generator

iterationsuint32

Total number of iterations of the MCMC chain

burninuint32

Number of iterations to discard as burn-in

init_num_clustersuint32 +

+ Number of clusters to initialize the algorithm. It may be + overridden by conditional mixings for which the number of + components is fixed (e.g. TruncatedSBMixing). In this case, this + value is ignored. +

+
neal8_n_auxuint32 +

Number of auxiliary unique values for the Neal8 algorithm

+
splitmerge_n_restr_gs_updatesuint32

Number of restricted GS scans for each MH step.

splitmerge_n_mh_updatesuint32 +

+ Number of MH updates for each iteration of Split and Merge + algorithm. +

+
splitmerge_n_full_gs_updatesuint32 +

+ Number of full GS scans for each iteration of Split and Merge + algorithm. +

+
+ +
+

algorithm_state.proto

+ Top +
+

+ +

AlgorithmState

+

This message represents the state of a Gibbs sampler for

+

a mixture model. All algorithms must be able to handle this

+

message, by filling it with the current state of the sampler

+

in the `get_state_as_proto` method.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
cluster_states + AlgorithmState.ClusterState + repeated

The state of each cluster

cluster_allocsint32repeated +

Vector of allocations into clusters, one for each observation

+
mixing_stateMixingState

The state of the `Mixing`

iteration_numint32

The iteration number

hierarchy_hypers + AlgorithmState.HierarchyHypers + +

The current values of the hyperparameters of the hierarchy

+
+ +

+ AlgorithmState.ClusterState +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
uni_ls_stateUniLSState

State of a univariate location-scale family

multi_ls_stateMultiLSState

State of a multivariate location-scale family

lin_reg_uni_ls_stateLinRegUniLSState +

State of a linear regression univariate location-scale family

+
general_stateVector

Just a vector of doubles

fa_stateFAState

State of a Mixture of Factor Analysers

cardinalityint32

How many observations are in this cluster

+ +

+ AlgorithmState.HierarchyHypers +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
general_stateVector

nnig_stateNIGDistribution

nnw_stateNWDistribution

lin_reg_uni_state + MultiNormalIGDistribution +

nnxig_stateNxIGDistribution

fa_state + FAPriorDistribution +

+ +
+

distribution.proto

+ Top +
+

+ +

BetaDistribution

+

Parameters defining a beta distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
shape_adouble

shape_bdouble

+ +

GammaDistribution

+

Parameters defining a gamma distribution with density

+

f(x) = x^(shape-1) * exp(-rate * x) / Gamma(shape)

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
shapedouble

ratedouble

+ +

InvWishartDistribution

+

Parameters defining an Inverse Wishart distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
deg_freedouble

scaleMatrix

+ +

MultiNormalDistribution

+

Parameters defining a multivariate normal distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

varMatrix

+ +

MultiNormalIGDistribution

+

+ Parameters for the Normal Inverse Gamma distribution commonly employed in +

+

linear regression models, with density

+

+ f(beta, var) = N(beta | mean, var * var_scaling^{-1}) * IG(var | shape, + scale) +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

var_scalingMatrix

shapedouble

scaledouble

+ +

NIGDistribution

+

Parameters of a Normal Inverse-Gamma distribution

+

with density

+

f(x, y) = N(x | mu, y/var_scaling) * IG(y | shape, scale)

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

var_scalingdouble

shapedouble

scaledouble

+ +

NWDistribution

+

Parameters of a Normal Wishart distribution

+

with density

+

f(x, y) = N(x | mu, (y * var_scaling)^{-1}) * IW(y | deg_free, scale)

+

where x is a vector and y is a matrix (spd)

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

var_scalingdouble

deg_freedouble

scaleMatrix

+ +

NxIGDistribution

+

Parameters of a Normal x Inverse-Gamma distribution

+

with density

+

f(x, y) = N(x | mu, var) * IG(y | shape, scale)

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

shapedouble

scaledouble

+ +

UniNormalDistribution

+

Parameters defining a univariate normal distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

+ +
+

hierarchy_id.proto

+ Top +
+

+ +

HierarchyId

+

Enum for the different types of Hierarchy.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameNumberDescription
UNKNOWN_HIERARCHY0

NNIG1

Normal - Normal Inverse Gamma

NNW2

Normal - Normal Wishart

LinRegUni3

Linear Regression (univariate response)

LapNIG4

Laplace - Normal Inverse Gamma

FA5

Factor Analysers

NNxIG6

Normal - Normal x Inverse Gamma

+ +
+

hierarchy_prior.proto

+ Top +
+

+ +

EmptyPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fake_fielddouble

+ +

FAPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_values + FAPriorDistribution +

+ +

FAPriorDistribution

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mutildeVector

betaVector

phidouble

alpha0double

quint32

+ +

LapNIGPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesLapNIGState

+ +

LapNIGState

+

+ Prior for the parameters of the base measure in a Laplace - Normal Inverse + Gamma hierarchy +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

shapedouble

scaledouble

mh_mean_vardouble

mh_log_scale_vardouble

+ +

LinRegUniPrior

+

+ Prior for the parameters of the base measure in a Normal mixture model + with a covariate-dependent +

+

location.

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_values + MultiNormalIGDistribution +

+ +

NNIGPrior

+

+ Prior for the parameters of the base measure in a Normal-Normal Inverse + Gamma hierarchy +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesNIGDistribution

no prior, just fixed values

normal_mean_prior + NNIGPrior.NormalMeanPrior +

prior on the mean

ngg_priorNNIGPrior.NGGPrior

prior on the mean, var_scaling, and scale

+ +

NNIGPrior.NGGPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + UniNormalDistribution +

var_scaling_priorGammaDistribution

shapedouble

scale_priorGammaDistribution

+ +

NNIGPrior.NormalMeanPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + UniNormalDistribution +

var_scalingdouble

shapedouble

scaledouble

+ +

NNWPrior

+

+ Prior for the parameters of the base measure in a Normal-Normal Wishart + hierarchy +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesNWDistribution

no prior, just fixed values

normal_mean_prior + NNWPrior.NormalMeanPrior +

prior on the mean

ngiw_priorNNWPrior.NGIWPrior

prior on the mean, var_scaling, and scale

+ +

NNWPrior.NGIWPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + MultiNormalDistribution +

var_scaling_priorGammaDistribution

deg_freedouble

scale_prior + InvWishartDistribution +

+ +

NNWPrior.NormalMeanPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + MultiNormalDistribution +

var_scalingdouble

deg_freedouble

scaleMatrix

+ +

NNxIGPrior

+

+ Prior for the parameters of the base measure in a Normal-Normal x Inverse + Gamma hierarchy +

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesNxIGDistribution

no prior, just fixed values

+ +
+

ls_state.proto

+ Top +
+

+ +

FAState

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
muVector

psiVector

etaMatrix

lambdaMatrix

+ +

LinRegUniLSState

+

Parameters of a univariate linear regression

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
regression_coeffsVector

regression coefficients

vardouble

variance of the noise

+ +

MultiLSState

+

Parameters of a multivariate location-scale family of distributions,

+

parameterized by mean and precision (inverse of variance). For

+

+ convenience, we also store the Cholesky factor of the precision matrix. +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

precMatrix

prec_cholMatrix

+ +

UniLSState

+

Parameters of a univariate location-scale family of distributions.

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

+ +
+

matrix.proto

+ Top +
+

+ +

Matrix

+

Message representing a matrix of doubles.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
rowsint32

number of rows

colsint32

number of columns

datadoublerepeated

matrix elements

rowmajorbool

if true, the data is read in row-major order

+ +

Vector

+

Message representing a vector of doubles.

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
sizeint32

number of elements in the vector

datadoublerepeated

vector elements

+ +
+

mixing_id.proto

+ Top +
+

+ +

MixingId

+

Enum for the different types of Mixing.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameNumberDescription
UNKNOWN_MIXING0

DP1

Dirichlet Process

PY2

Pitman-Yor Process

LogSB3

Logit Stick-Breaking Process

TruncSB4

Truncated Stick-Breaking Process

MFM5

Mixture of finite mixtures

+ +
+

mixing_prior.proto

+ Top +
+

+ +

DPPrior

+

Prior for the concentration parameter of a Dirichlet process

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valueDPState

No prior, just a fixed value

gamma_priorDPPrior.GammaPrior

Gamma prior on the total mass

+ +

DPPrior.GammaPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmass_priorGammaDistribution

+ +

LogSBPrior

+

Definition of the parameters of a Logit-Stick Breaking process.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
normal_prior + MultiNormalDistribution +

Normal prior on the regression coefficients

step_sizedouble +

+ Steps size for the MALA algorithm used for posterior inference + (TODO: move?) +

+
num_componentsuint32

Number of components in the process

+ +

MFMPrior

+

+ Prior for the Poisson rate and Dirichlet parameters of a MFM (Finite + Dirichlet) process. +

+

For the moment, we only support fixed values

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valueMFMState

No prior, just a fixed value

+ +

PYPrior

+

+ Prior for the strength and discount parameters of a Pitman-Yor process. +

+

For the moment, we only support fixed values

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesPYState

+ +

TruncSBPrior

+

Definition of the parameters of a truncated Stick-Breaking process

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
beta_priors + TruncSBPrior.BetaPriors +

General stick-breaking distributions

dp_prior + TruncSBPrior.DPPrior +

Truncated Dirichlet process

py_prior + TruncSBPrior.PYPrior +

Truncated Pitman-Yor process

mfm_prior + TruncSBPrior.MFMPrior +

num_componentsuint32

Number of components in the process

+ +

TruncSBPrior.BetaPriors

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
beta_distributionsBetaDistributionrepeated

General stick-breaking distributions

+ +

TruncSBPrior.DPPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

+ +

TruncSBPrior.MFMPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

+ +

TruncSBPrior.PYPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
strengthdouble

Truncated Pitman-Yor process

discountdouble

+ +
+

mixing_state.proto

+ Top +
+

+ +

DPState

+

State of a Dirichlet process

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmassdouble

the total mass of the DP

+ +

LogSBState

+

State of a Logit-Stick Breaking process

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
regression_coeffsMatrix +

+ Num_Components x Num_Features matrix. Each row is the regression + coefficients for a component. +

+
+ +

MFMState

+

State of a MFM (Finite Dirichlet) process

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
lambdadouble +

+ rate parameter of Poisson prior on number of compunents of the MFM +

+
gammadouble +

+ parameter of the dirichlet distribution for the mixing weights +

+
+ +

MixingState

+

Wrapper of all possible mixing states into a single oneof

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
dp_stateDPState

py_statePYState

log_sb_stateLogSBState

trunc_sb_stateTruncSBState

mfm_stateMFMState

+ +

PYState

+

State of a Pitman-Yor process

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
strengthdouble

discountdouble

+ +

TruncSBState

+

+ State of a truncated sitck breaking process. For convenice we store also + the logarithm of the weights +

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
sticksVector

logweightsVector

+ +
+

mixture_model.proto

+ Top +
+

+ +

HierarchyPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
nnig_priorNNIGPrior

lapnig_priorLapNIGPrior

nnw_priorNNWPrior

lin_reg_priorLinRegUniPrior

fa_priorFAPrior

+ +

MixingPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
dp_priorDPPrior

py_priorPYPrior

log_sb_priorLogSBPrior

trunc_sb_priorTruncSBPrior

+ +

MixtureModel

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mixingMixingId

hierarchyHierarchyId

mixing_priorMixingPrior

hierarchy_priorHierarchyPrior

+ +
+

semihdp.proto

+ Top +
+

+ +

SemiHdpParams

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
pseudo_prior + SemiHdpParams.PseudoPriorParams +

dirichlet_concentrationdouble

rest_allocs_updatestring +

+ Either "full", "metro_base", "metro_dist" +

+
totalmass_restdouble

totalmass_hdpdouble

w_prior + SemiHdpParams.WPriorParams +

+ +

+ SemiHdpParams.PseudoPriorParams +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
card_weightdouble

mean_perturb_sddouble

var_perturb_fracdouble

+ +

SemiHdpParams.WPriorParams

+

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
shape1double

shape2double

+ +

SemiHdpState

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
restaurants + SemiHdpState.RestaurantState + repeated

groups + SemiHdpState.GroupState + repeated

taus + SemiHdpState.ClusterState + repeated

cint32repeated

wdouble

+ +

SemiHdpState.ClusterState

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
uni_ls_stateUniLSState

multi_ls_stateMultiLSState

lin_reg_uni_ls_stateLinRegUniLSState

cardinalityint32

+ +

SemiHdpState.GroupState

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
cluster_allocsint32repeated

+ +

+ SemiHdpState.RestaurantState +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
theta_stars + SemiHdpState.ClusterState + repeated

n_by_clusint32repeated

table_to_sharedint32repeated

table_to_idioint32repeated

Scalar Value Types

- + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
.proto TypeNotesC++JavaPythonGoC#PHPRuby
.proto TypeNotesC++JavaPythonGoC#PHPRuby
doubledoubledoublefloatfloat64doublefloatFloat
floatfloatfloatfloatfloat32floatfloatFloat
int32Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint32 instead.int32intintint32intintegerBignum or Fixnum (as required)
int64Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint64 instead.int64longint/longint64longinteger/stringBignum
uint32Uses variable-length encoding.uint32intint/longuint32uintintegerBignum or Fixnum (as required)
uint64Uses variable-length encoding.uint64longint/longuint64ulonginteger/stringBignum or Fixnum (as required)
sint32Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int32s.int32intintint32intintegerBignum or Fixnum (as required)
sint64Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int64s.int64longint/longint64longinteger/stringBignum
fixed32Always four bytes. More efficient than uint32 if values are often greater than 2^28.uint32intintuint32uintintegerBignum or Fixnum (as required)
fixed64Always eight bytes. More efficient than uint64 if values are often greater than 2^56.uint64longint/longuint64ulonginteger/stringBignum
sfixed32Always four bytes.int32intintint32intintegerBignum or Fixnum (as required)
sfixed64Always eight bytes.int64longint/longint64longinteger/stringBignum
boolboolbooleanbooleanboolboolbooleanTrueClass/FalseClass
stringA string must always contain UTF-8 encoded or 7-bit ASCII text.stringStringstr/unicodestringstringstringString (UTF-8)
bytesMay contain any arbitrary sequence of bytes.stringByteStringstr[]byteByteStringstringString (ASCII-8BIT)
doubledoubledoublefloatfloat64doublefloatFloat
floatfloatfloatfloatfloat32floatfloatFloat
int32 + Uses variable-length encoding. Inefficient for encoding negative + numbers – if your field is likely to have negative values, use + sint32 instead. + int32intintint32intintegerBignum or Fixnum (as required)
int64 + Uses variable-length encoding. Inefficient for encoding negative + numbers – if your field is likely to have negative values, use + sint64 instead. + int64longint/longint64longinteger/stringBignum
uint32Uses variable-length encoding.uint32intint/longuint32uintintegerBignum or Fixnum (as required)
uint64Uses variable-length encoding.uint64longint/longuint64ulonginteger/stringBignum or Fixnum (as required)
sint32 + Uses variable-length encoding. Signed int value. These more + efficiently encode negative numbers than regular int32s. + int32intintint32intintegerBignum or Fixnum (as required)
sint64 + Uses variable-length encoding. Signed int value. These more + efficiently encode negative numbers than regular int64s. + int64longint/longint64longinteger/stringBignum
fixed32 + Always four bytes. More efficient than uint32 if values are often + greater than 2^28. + uint32intintuint32uintintegerBignum or Fixnum (as required)
fixed64 + Always eight bytes. More efficient than uint64 if values are often + greater than 2^56. + uint64longint/longuint64ulonginteger/stringBignum
sfixed32Always four bytes.int32intintint32intintegerBignum or Fixnum (as required)
sfixed64Always eight bytes.int64longint/longint64longinteger/stringBignum
boolboolbooleanbooleanboolboolbooleanTrueClass/FalseClass
string + A string must always contain UTF-8 encoded or 7-bit ASCII text. + stringStringstr/unicodestringstringstringString (UTF-8)
bytesMay contain any arbitrary sequence of bytes.stringByteStringstr[]byteByteStringstringString (ASCII-8BIT)
- From 1ebf53b6577c4aaa066f7315f70b0ca01f47d558 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 16 May 2022 16:54:34 +0200 Subject: [PATCH 290/317] more submodules --- docs/prior_models.rst | 83 +++++++++++++++++++++++++++++++++++++++++++ docs/states.rst | 38 ++++++++++++++++++++ docs/updaters.rst | 77 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 198 insertions(+) create mode 100644 docs/prior_models.rst create mode 100644 docs/states.rst create mode 100644 docs/updaters.rst diff --git a/docs/prior_models.rst b/docs/prior_models.rst new file mode 100644 index 000000000..8df7cc049 --- /dev/null +++ b/docs/prior_models.rst @@ -0,0 +1,83 @@ +bayesmix/hierarchies/prior_models + +Prior Models +============ + +A ``PriorModel`` represents the prior for the parameters in the likelihood, i.e. + +.. math:: + \bm{\tau} \sim G_{0} + +with :math:`G_{0}` being a suitable prior on the parameters space. We also allow for more flexible priors adding further level of randomness (i.e. the hyperprior) on the parameter characterizing :math:`G_{0}` + +------------------------- +Main operations performed +------------------------- + +A likelihood object must be able to perform the following operations: + +a. First of all, ``lpdf()`` and ``lpdf_from_unconstrained()`` methods evaluate the log-prior density function at the current state :math:`\bm \tau` or its unconstrained representation. +In particular, ``lpdf_from_unconstrained()`` is needed by Metropolis-like updaters. + +b. The ``sample()`` method generates a draw from the prior distribution. If ``hier_hypers`` is ``nullptr, the prior hyperparameter values are used. +To allow sampling from the full conditional distribution in case of semi-congugate hierarchies, we introduce the ``hier_hypers`` parameter, which is a pointer to a ``Protobuf`` message storing the hierarchy hyperaprameters to use for the sampling. + +c. The ``update_hypers()`` method updates the prior hyperparameters, given the vector of all cluster states. + + +-------------- +Code structure +-------------- + +As for the ``Likelihood`` classes we employ the Curiously Recurring Template Pattern to manage the polymorphic nature of ``PriorModel`` classes. + +The class ``AbstractPriorModel`` defines the API, i.e. all the methods that need to be called from outside of a ``PrioModel`` class. +A template class ``BasePriorModel`` inherits from ``AbstractPriorModel`` and implements some of the necessary virtual methods, which need not be implemented by the child classes. + +Instead, child classes **must** implement: + +a. ``lpdf``: evaluates :math:`G_0(\theta_h)` +b. ``sample``: samples from :math:`G_0` given a hyperparameters (passed as a pointer). If ``hier_hypers`` is ``nullptr``, the prior hyperparameter values are used. +c. ``set_hypers_from_proto``: sets the hyperparameters from a ``Probuf``message +d. ``get_hypers_proto``: returns the hyperparameters as a ``Probuf``message +e. ``initialize_hypers``: provides a default initialization of hyperparameters + +In case you want to use a Metropolis-like updater, child classes **should** also implement: + +f. ``lpdf_from_unconstrained``: evaluates :math:`G_0(\tilde{\theta}_h)`, where :math:`\tilde{\theta}_h` is the vector of unconstrained parameters. + +---------------- +Abstract Classes +---------------- + +.. doxygenclass:: AbstractPriorModel + :project: bayesmix + :members: +.. doxygenclass:: BasePriorModel + :project: bayesmix + :members: + +-------------------- +Non-abstract Classes +-------------------- + +.. doxygenclass:: NIGPriorModel + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: NxIGPriorModel + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: NWPriorModel + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: MNIGPriorModel + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: FAPriorModel + :project: bayesmix + :members: + :protected-members: diff --git a/docs/states.rst b/docs/states.rst new file mode 100644 index 000000000..6193081b1 --- /dev/null +++ b/docs/states.rst @@ -0,0 +1,38 @@ +bayesmix/hierarchies/likelihoods/states + +States +====== + +``States`` are classes used to store parameters :math:`\theta_h` of every mixture component. +Their main purpose is to handle serialization and de-serialization of the state. +Moreover, they allow to go from the constrained to the unconstrained representation of the parameters (and viceversa) and compute the associated determinant of the Jacobian appearing in the change of density formula. + + +-------------- +Code Structure +-------------- + +All classes must inherit from the `BaseState` class + +.. doxygenclass:: BaseState + :project: bayesmix + :members: + +Depending on the chosen ``Updater``, the unconstrained representation might not be needed, and the methods ``get_unconstrained()``, ``set_from_unconstrained()`` and ``log_det_jac()`` might never be called. +Therefore, we do not force users to implement them. +Instead, the ``set_from_proto()`` and ``get_as_proto()`` are fundamental as they allow the interaction with Google's Protocol Buffers library. + +------------- +State Classes +------------- + +.. doxygenclass:: UniLSState + :project: bayesmix + :members: +.. doxygenclass:: MultiLSState + :project: bayesmix + :members: +.. doxygenclass:: FAState + :project: bayesmix + :members: + :protected-members: diff --git a/docs/updaters.rst b/docs/updaters.rst new file mode 100644 index 000000000..9e00718ce --- /dev/null +++ b/docs/updaters.rst @@ -0,0 +1,77 @@ +bayesmix/hierarchies/updaters + +Updaters +======== + +An ``Updater`` implements the machinery to provide a sampling from the full conditional distribution of a given hierarchy. + +The only operation performed is ``draw`` that samples from the full conditional, either exactly or via Markov chain Monte Carlo. + +.. doxygenclass:: AbstractUpdater + :project: bayesmix + :members: + +-------------- +Code Structure +-------------- + +We distinguish between semi-conjugate updaters and the metropolis-like updaters. + + +Semi Conjugate Updaters +----------------------- + +A semi-conjugate updater can be used when the full conditional distribution has the same form of the prior. Therefore, to sample from the full conditional, it is sufficient to call the ``draw`` method of the prior, but with an updated set of hyperparameters. + +The class ``SemiConjugateUpdater`` defines the API + +.. doxygenclass:: SemiConjugateUpdater + :project: bayesmix + :members: + +Classes inheriting from this one should only implement the ``compute_posterior_hypers(...)`` member function. + + +Metropolis-like Updaters +------------------------ + +A Metropolis updater uses the Metropolis-Hastings algorithm (or its variations) to sample from the full conditional density. + +.. doxygenclass:: MetropolisUpdater + :project: bayesmix + :members: + + +Classes inheriting from this one should only implement the ``sample_proposal(...)`` method, which samples from the porposal distribution, and the ``proposal_lpdf`` one, which evaluates the proposal density log-probability density function. + +--------------- +Updater Classes +--------------- + +.. doxygenclass:: RandomWalkUpdater + :project: bayesmix + :members: +.. doxygenclass:: MalaUpdater + :project: bayesmix + :members: + +.. doxygenclass:: NNIGUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: NNxIGUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: NNWUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: MNIGUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: FAUpdater + :project: bayesmix + :members: + :protected-members: From 2bcad4ee5ac8765fb69bbc0f9a935af31e81e792 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 16 May 2022 17:57:09 +0200 Subject: [PATCH 291/317] latexifying stuff --- docs/algorithms.rst | 2 +- docs/prior_models.rst | 8 +++- docs/states.rst | 10 +++-- src/algorithms/conditional_algorithm.h | 20 +++++----- src/algorithms/marginal_algorithm.h | 24 ++++++------ src/hierarchies/fa_hierarchy.h | 12 +++--- src/hierarchies/lapnig_hierarchy.h | 15 +++++--- .../likelihoods/states/base_state.h | 18 ++++----- src/hierarchies/likelihoods/states/fa_state.h | 13 ++++--- src/hierarchies/nnig_hierarchy.h | 11 ++++-- src/hierarchies/nnw_hierarchy.h | 15 +++++--- src/hierarchies/priors/fa_prior_model.h | 12 +++--- src/hierarchies/priors/mnig_prior_model.h | 6 ++- src/hierarchies/priors/nig_prior_model.h | 12 +++--- src/hierarchies/priors/nw_prior_model.h | 11 ++++-- src/hierarchies/priors/nxig_prior_model.h | 6 ++- src/hierarchies/updaters/nnxig_updater.h | 9 +++-- src/mixings/dirichlet_mixing.h | 21 +++++----- src/mixings/logit_sb_mixing.h | 23 +++++------ src/mixings/mixture_finite_mixing.h | 31 ++++++++------- src/mixings/pityor_mixing.h | 22 +++++------ src/mixings/truncated_sb_mixing.h | 38 ++++++++++--------- 22 files changed, 191 insertions(+), 148 deletions(-) diff --git a/docs/algorithms.rst b/docs/algorithms.rst index 635800167..e15aff394 100644 --- a/docs/algorithms.rst +++ b/docs/algorithms.rst @@ -20,7 +20,7 @@ Algorithms .. doxygenclass:: Neal8Algorithm :project: bayesmix :members: -.. doxygenclass:: SplitMergeAlgorithm +.. doxygenclass:: SplitAndMergeAlgorithm :project: bayesmix :members: .. doxygenclass:: ConditionalAlgorithm diff --git a/docs/prior_models.rst b/docs/prior_models.rst index 8df7cc049..002ab1f85 100644 --- a/docs/prior_models.rst +++ b/docs/prior_models.rst @@ -38,8 +38,8 @@ Instead, child classes **must** implement: a. ``lpdf``: evaluates :math:`G_0(\theta_h)` b. ``sample``: samples from :math:`G_0` given a hyperparameters (passed as a pointer). If ``hier_hypers`` is ``nullptr``, the prior hyperparameter values are used. -c. ``set_hypers_from_proto``: sets the hyperparameters from a ``Probuf``message -d. ``get_hypers_proto``: returns the hyperparameters as a ``Probuf``message +c. ``set_hypers_from_proto``: sets the hyperparameters from a ``Probuf`` message +d. ``get_hypers_proto``: returns the hyperparameters as a ``Probuf`` message e. ``initialize_hypers``: provides a default initialization of hyperparameters In case you want to use a Metropolis-like updater, child classes **should** also implement: @@ -65,18 +65,22 @@ Non-abstract Classes :project: bayesmix :members: :protected-members: + .. doxygenclass:: NxIGPriorModel :project: bayesmix :members: :protected-members: + .. doxygenclass:: NWPriorModel :project: bayesmix :members: :protected-members: + .. doxygenclass:: MNIGPriorModel :project: bayesmix :members: :protected-members: + .. doxygenclass:: FAPriorModel :project: bayesmix :members: diff --git a/docs/states.rst b/docs/states.rst index 6193081b1..b20fddaac 100644 --- a/docs/states.rst +++ b/docs/states.rst @@ -14,7 +14,7 @@ Code Structure All classes must inherit from the `BaseState` class -.. doxygenclass:: BaseState +.. doxygenclass:: State::BaseState :project: bayesmix :members: @@ -26,13 +26,15 @@ Instead, the ``set_from_proto()`` and ``get_as_proto()`` are fundamental as they State Classes ------------- -.. doxygenclass:: UniLSState +.. doxygenclass:: State::UniLS :project: bayesmix :members: -.. doxygenclass:: MultiLSState + +.. doxygenclass:: State::MultiLS :project: bayesmix :members: -.. doxygenclass:: FAState + +.. doxygenclass:: State::FA :project: bayesmix :members: :protected-members: diff --git a/src/algorithms/conditional_algorithm.h b/src/algorithms/conditional_algorithm.h index 988780c96..d83d5f7af 100644 --- a/src/algorithms/conditional_algorithm.h +++ b/src/algorithms/conditional_algorithm.h @@ -12,15 +12,17 @@ //! This template class implements a generic Gibbs sampling conditional //! algorithm as the child of the `BaseAlgorithm` class. //! A mixture model sampled from a conditional algorithm can be expressed as -//! x_i | c_i, phi_1, ..., phi_k ~ f(x_i|phi_(c_i)) (data likelihood); -//! phi_1, ... phi_k ~ G (unique values); -//! c_1, ... c_n | w_1, ..., w_k ~ Cat(w_1, ... w_k) (cluster allocations); -//! w_1, ..., w_k ~ p(w_1, ..., w_k) (mixture weights) -//! where f(x | phi_j) is a density for each value of phi_j, the c_i take -//! values in {1, ..., k} and w_1, ..., w_k are nonnegative weights whose sum -//! is a.s. 1, i.e. p(w_1, ... w_k) is a probability distribution on the k-1 -//! dimensional unit simplex). -//! In this library, each phi_j is represented as an `Hierarchy` object (which +//! \f[ +//! x_i | c_i, \theta_1, ..., \theta_k & \sim f(x_i|\theta_(c_i)) \\ +//! \theta_1, ..., \theta_k & \sim G_0 \\ +//! c_1, ... c_n | w_1, ..., w_k & \sim \text{Cat}(w_1, ... w_k) \\ +//! w_1, ..., w_k & \sim p(w_1, ..., w_k) +//! \f] +//! where \f$f(x | \theta_j)\f$ is a density for each value of \f$\theta_j\f$, +//! \f$c_i\f$ take values in \f${1, ..., k}\f$ and \f$w_1, ..., w_k\f$ are +//! nonnegative weights whose sum is a.s. 1, i.e. \f$p(w_1, ... w_k)\f$ is a +//! probability distribution on the k-1 dimensional unit simplex). n this +//! library, each \f$\theta_j\f$ is represented as an `Hierarchy` object (which //! inherits from `AbstractHierarchy`), that also holds the information related //! to the base measure `G` is (see `AbstractHierarchy`). //! The weights (w_1, ..., w_k) are represented as a `Mixing` object, which diff --git a/src/algorithms/marginal_algorithm.h b/src/algorithms/marginal_algorithm.h index 34af2167a..d0c880e10 100644 --- a/src/algorithms/marginal_algorithm.h +++ b/src/algorithms/marginal_algorithm.h @@ -13,17 +13,19 @@ //! This template class implements a generic Gibbs sampling marginal algorithm //! as the child of the `BaseAlgorithm` class. //! A mixture model sampled from a Marginal Algorithm can be expressed as -//! x_i | c_i, phi_1, ..., phi_k ~ f(x_i|phi_(c_i)) (data likelihood); -//! phi_1, ... phi_k ~ G (unique values); -//! c_1, ... c_n ~ EPPF(c_1, ... c_n) (cluster allocations); -//! where f(x | phi_j) is a density for each value of phi_j and the c_i take -//! values in {1, ..., k}. -//! Depending on the actual implementation, the algorithm might require -//! the kernel/likelihood f(x | phi) and G(phi) to be conjugagte or not. -//! In the former case, a `ConjugateHierarchy` must be specified. -//! In this library, each phi_j is represented as an `Hierarchy` object (which -//! inherits from `AbstractHierarchy`), that also holds the information related -//! to the base measure `G` is (see `AbstractHierarchy`). The EPPF is instead +//! \f[ +//! x_i | c_i, \theta_1, ..., \theta_k & \sim f(x_i|\theta_(c_i)) \\ +//! \theta_1, ..., \theta_k & \sim G_0 \\ +//! c_1, ... c_n & \sim EPPF(c_1, ... c_n) +//! \f] +//! where \f$f(x | \theta_j)\f$ is a density for each value of \f$\theta_j\f$ +//! and \f$c_i\f$ take values in \f${1, ..., k}\f$. Depending on the actual +//! implementation, the algorithm might require the kernel/likelihood \f$f(x | +//! \theta)\f$ and \f$G_0(phi)\f$ to be conjugagte or not. In the former case, +//! a conjugate hierarchy must be specified. In this library, each +//! \f$\theta_j\f$ is represented as an `Hierarchy` object (which inherits from +//! `AbstractHierarchy`), that also holds the information related to the base +//! measure \f$G_0\f$ is (see `AbstractHierarchy`). The EPPF is instead //! represented as a `Mixing` object, which inherits from `AbstractMixing`. //! //! The state of a marginal algorithm only consists of allocations and unique diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index f98dbc7c5..c0f8ecd63 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -15,11 +15,13 @@ //! of the covariance matrix (see the `FAHierarchy` class for details). The //! likelihood parameters have a Dirichlet-Laplace distribution x InverseGamma //! centering distribution (see the `FAPriorModel` class for details). That is: -//! f(x_i|mu,Sigma,Lambda) = N(mu,Sigma+Lambda*Lambda^T) -//! mu ~ N(mu0,psi*I) -//! Lambda ~ DL(alpha) -//! Sigma = diag(sig1^2,...,sigp^2) -//! sigj^2 ~ IG(a,b) for j=1,...,p +//! \f[ +//! f(x_i| \mu, \Sigma, \Lambda) &= N(\mu, \Sigma + \Lambda \Lambda^T) \\ +//! \mu &\sim N_p(\tilde \mu, \psi I) \\ +//! \Lambda &\sim DL(\alpha) \\ +//! \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ +//! \sigma^2_j &\sim IG(a,b) \quad j=1,...,p +//! \f] //! where Lambda is the latent score matrix (size p x d with d << p) and //! DL(alpha) is the Laplace-Dirichlet distribution. //! See Bhattacharya et al. (2015) for further details. diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/lapnig_hierarchy.h index 000b60bd5..d4803a5c8 100644 --- a/src/hierarchies/lapnig_hierarchy.h +++ b/src/hierarchies/lapnig_hierarchy.h @@ -13,13 +13,16 @@ //! according to a laplace likelihood (see the `LaplaceLikelihood` class for //! deatils).The likelihood parameters have a Normal x InverseGamma centering //! distribution (see the `NxIGPriorModel` class for details). That is: -//! f(x_i|mu,lambda) = Laplace(mu,sqrt(var/2)) -//! mu ~ N(mu0,sig0^2) -//! var ~ IG(alpha0,beta0) +//! \f[ +//! f(x_i|\mu,\sigma^2) &= Laplace(\mu,\sqrt(\sigma^2/2))\\ +//! \mu &\sim N(\mu_0,\eta^2) \\ +//! \sigma^2 ~ IG(a, b) +//! \f] //! The state is composed of mean and variance (thus the scale for the Laplace -//! distribution is sqrt(var / 2)). The state hyperparameters are (mu_0, -//! sig0^2, alpha0, beta0), all scalar values. Note that this hierarchy is NOT -//! conjugate, thus the marginal distribution is not available in closed form. +//! distribution is \f$ \sqrt(\sigma^2 / 2)) \f$. The state hyperparameters are +//! \f$(mu_0, \sigma^2, a, b)\f$, all scalar values. Note that this hierarchy +//! is NOT conjugate, thus the marginal distribution is not available in closed +//! form. class LapNIGHierarchy : public BaseHierarchy { diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index cb0744531..4217913d1 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -14,14 +14,17 @@ //! are distributed according to a multivariate normal likelihood (see the //! `MultiNormLikelihood` for details). The likelihood parameters have a //! Normal-Wishart centering distribution (see the `NWPriorModel` class for -//! details). That is: f(x_i|mu,tau) = N(mu,tau^{-1}) -//! (mu,tau) ~ NW(mu0, lambda0, tau0, nu0) +//! details). That is: +//! \f[ +//! f(x_i|\mu,\Sigma) &= N(\mu,\Sigma^{-1}) \\ +//! (\mu,\Sigma) &\sim NW(\mu_0, \lambda, \Psi_0, \nu_0) +//! \f] //! The state is composed of mean and precision matrix. The Cholesky factor and //! log-determinant of the latter are also included in the container for -//! efficiency reasons. The state's hyperparameters are (mu0, lambda0, tau0, -//! nu0), which are respectively vector, scalar, matrix, and scalar. Note that -//! this hierarchy is conjugate, thus the marginal distribution is available in -//! closed form. +//! efficiency reasons. The state's hyperparameters are \f$(\mu_0, \lambda, +//! \Psi_0, \nu_0)\f$, which are respectively vector, scalar, matrix, and +//! scalar. Note that this hierarchy is conjugate, thus the marginal +//! distribution is available in closed form. class NNWHierarchy : public BaseHierarchy { diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index f35ac03fa..d1689d74f 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -10,11 +10,13 @@ #include "hyperparams.h" #include "src/utils/rng.h" -//! A prior model for the factor analyzers likelihood, that is -//! mu ~ N_p(mutilde, psi*I) -//! Lambda ~ DL(alpha) -//! Sigma = diag(sigsq_1,...,sigsq_p) -//! sigsq_j ~ IG(a,b) j=1,...,p +//! A priormodel for the factor analyzers likelihood, that is +//! \f[ +//! \mu &\sim N_p(\tilde \mu, \psi I) \\ +//! \Lambda &\sim DL(\alpha) \\ +//! \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ +//! \sigma^2_j &\sim IG(a,b) \quad j=1,...,p +//! \f] //! Where DL is the Dirichlet-Laplace distribution. See Bhattacharya A., Pati //! D, Pillai N.S., Dunson D.B. (2015). JASA 110(512), 1479–1490 for details. diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h index 885993db9..1202ab551 100644 --- a/src/hierarchies/priors/mnig_prior_model.h +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -11,8 +11,10 @@ #include "src/utils/rng.h" //! A conjugate prior model for the scalar linear regression likelihood, i.e. -//! reg_coeffs | var ~ N_p(mu, var * Lambda^-1) -//! var ~ IG(a,b) +//! \f[ +//! \beta | \sigma^2 & \sim N_p(\mu, \sigma^2 \Lambda^-1) \\ +//! \sigma^2 & \sim IG(a,b) +//! \f] class MNIGPriorModel : public BasePriorModel { public: diff --git a/src/mixings/logit_sb_mixing.h b/src/mixings/logit_sb_mixing.h index 228f4da40..51fd33fff 100644 --- a/src/mixings/logit_sb_mixing.h +++ b/src/mixings/logit_sb_mixing.h @@ -12,15 +12,22 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" +namespace LogitSB { +struct State { + Eigen::MatrixXd regression_coeffs, precision; +}; +}; // namespace LogitSB + //! Class that represents the logit stick-breaking process indroduced in Rigon //! and Durante (2020). //! That is, a prior for weights (w_1,...,w_H), depending on covariates x in //! R^p, in the H-1 dimensional unit simplex, defined as follows: -//! w_1(x) = v_1(x) -//! w_j(x) = v_j(x) (1 - v_1(x)) ... (1 - v_{j-1}(x)), for j=2, ... H-1 -//! w_H(x) = 1 - (w_1(x) + w_2 + ... + w_{H-1}(x)) -//! and -//! v_j(x) = 1 / exp(- ), for j = 1, ..., H-1 +//! \f[ +//! w_1 &= v_1\\ +//! w_j &= v_j (1 - v_1) ... (1 - v_{j-1}), \quad \text{for } j=1, ... H-1 \\ +//! w_H &= 1 - (w_1 + w_2 + ... + w_{H-1}) \\ +//! v_j(x) &= 1 / exp(- <\alpha_j, x> ), for j = 1, ..., H-1 +//! \f] //! //! The main difference with the mentioned paper is that the authors propose a //! Gibbs sampler in which the full conditionals are available in close form @@ -30,12 +37,6 @@ //! For more information about the class, please refer instead to base classes, //! `AbstractMixing` and `BaseMixing`. -namespace LogitSB { -struct State { - Eigen::MatrixXd regression_coeffs, precision; -}; -}; // namespace LogitSB - class LogitSBMixing : public BaseMixing { public: diff --git a/src/mixings/mixture_finite_mixing.h b/src/mixings/mixture_finite_mixing.h index c1adc3dfa..0a38752dd 100644 --- a/src/mixings/mixture_finite_mixing.h +++ b/src/mixings/mixture_finite_mixing.h @@ -12,18 +12,27 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" +namespace Mixture_Finite { +struct State { + double lambda, gamma; +}; +}; // namespace Mixture_Finite + //! Class that represents the Mixture of Finite Mixtures (MFM) [1] //! The basic idea is to take usual finite mixture model with Dirichlet weights //! and put a prior (Poisson) on the number of components. The EPPF induced by -//! MFM depends on a Dirichlet parameter 'gamma' and number V_n(t), where -//! V_n(t) depends on the Poisson rate parameter 'lambda'. -//! V_n(t) = sum_{k=1}^{inf} ( k_(t)*p_K(k) / (gamma*k)^(n) ) +//! MFM depends on a Dirichlet parameter 'gamma' and number \f$V_n(t)\f$, where +//! \f$V_n(t)\f$ depends on the Poisson rate parameter 'lambda'. +//! \f[ +//! V_n(t) = \sum_{k=1}^{\infty} ( k_(t)p_K(k) / (\gamma*k)^(n) ) +//! \f] //! Given a clustering of n elements into k clusters, each with cardinality -//! n_j, j=1, ..., k, the EPPF of the MFM gives the following probabilities for -//! the cluster membership of the (n+1)-th observation: -//! denominator = n_j + gamma / (n + gamma*(n_clust + V[n_clust+1]/V[n_clust])) -//! p(j-th cluster | ...) = (n_j + gamma) / denominator -//! p(k+1-th cluster | ...) = V[n_clust+1]/V[n_clust]*gamma / denominator +//! \f$n_j, j=1, ..., k\f$, the EPPF of the MFM gives the following +//! probabilities for the cluster membership of the (n+1)-th observation: \f[ +//! p(\text{j-th cluster} | ...) &= (n_j + \gamma) / D \\ +//! p(\text{k+1-th cluster} | ...) &= V[k+1]/V[k] \gamma / D \\ +//! D &= n_j + \gamma / (n + \gamma * (k + V[k+1]/V[k])) +//! \f] //! For numerical reasons each value of V is multiplied with a constant C //! computed as the first term of the series of V_n[0]. //! For more information about the class, please refer instead to base @@ -31,12 +40,6 @@ //! [1] "Mixture Models with a Prior on the Number of Components", J.W.Miller //! and M.T.Harrison, 2015, arXiv:1502.06241v1 -namespace Mixture_Finite { -struct State { - double lambda, gamma; -}; -}; // namespace Mixture_Finite - class MixtureFiniteMixing : public BaseMixing { diff --git a/src/mixings/pityor_mixing.h b/src/mixings/pityor_mixing.h index 5ebec6958..6dbedf569 100644 --- a/src/mixings/pityor_mixing.h +++ b/src/mixings/pityor_mixing.h @@ -12,26 +12,26 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" +namespace PitYor { +struct State { + double strength, discount; +}; +}; // namespace PitYor + //! Class that represents the Pitman-Yor process (PY) in Pitman and Yor (1997). //! The EPPF induced by the PY depends on a `strength` parameter M and a //! `discount` paramter d. //! Given a clustering of n elements into k clusters, each with cardinality -//! n_j, j=1, ..., k, the EPPF of the PY gives the following probabilities for -//! the cluster membership of the (n+1)-th observation: -//! p(j-th cluster | ...) \propto (n_j - d) -//! p(k+1-th cluster | ...) \propto M + k * d -//! +//! \f$ n_j, j=1, ..., k \f$, the EPPF of the PY gives the following +//! probabilities for the cluster membership of the (n+1)-th observation: \f[ +//! p(\text{j-th cluster} | ...) \propto (n_j - d) \\ +//! p(\text{new cluster} | ...) \propto M + k d +//! \f] //! When `discount=0`, the EPPF of the PY process coincides with the one of the //! DP with totalmass = strength. //! For more information about the class, please refer instead to base classes, //! `AbstractMixing` and `BaseMixing`. -namespace PitYor { -struct State { - double strength, discount; -}; -}; // namespace PitYor - class PitYorMixing : public BaseMixing { public: diff --git a/src/mixings/truncated_sb_mixing.h b/src/mixings/truncated_sb_mixing.h index 120a6cc20..3af68249b 100644 --- a/src/mixings/truncated_sb_mixing.h +++ b/src/mixings/truncated_sb_mixing.h @@ -12,30 +12,32 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" -//! Class that represents a truncated stick-breaking process, as shown in -//! Ishwaran and James (2001). -//! -//! A truncated stick-breaking process is a prior for weights (w_1,...,w_H) in -//! the H-1 dimensional unit simplex, and is defined as follows: -//! w_1 = v_1 -//! w_j = v_j (1 - v_1) ... (1 - v_{j-1}), for j=1, ... H-1 -//! w_H = 1 - (w_1 + w_2 + ... + w_{H-1}) -//! The v_j's are called sticks and we assume them to be independently -//! distributed as v_j ~ Beta(a_j, b_j). -//! -//! When a_j = 1 and b_j = M, the stick-breaking process is a truncation of the -//! stick-breaking representation of the DP. -//! When a_j = 1 - d and b_j = M + i*d, it is the trunctation of a PY process. -//! Its state is composed of the weights w_j in log-scale and the sticks v_j. -//! For more information about the class, please refer instead to base classes, -//! `AbstractMixing` and `BaseMixing`. - namespace TruncSB { struct State { Eigen::VectorXd sticks, logweights; }; }; // namespace TruncSB +//! Class that represents a truncated stick-breaking process, as shown in +//! Ishwaran and James (2001). +//! +//! A truncated stick-breaking process is a prior for weights +//! \f$(w_1,...,w_H)\f$ in the H-1 dimensional unit simplex, and is defined as +//! follows: \f[ +//! w_1 &= v_1\\ +//! w_j &= v_j (1 - v_1) ... (1 - v_{j-1}), \quad \text{for } j=1, ... H-1 \\ +//! w_H &= 1 - (w_1 + w_2 + ... + w_{H-1}) +//! \f] +//! The \f$v_j\f$'s are called sticks and we assume them to be independently +//! distributed as \f$v_j \sim \text{Beta}(a_j, b_j)\f$. +//! +//! When \f$a_j = 1\f$ and \f$b_j = M\f$, the stick-breaking process is a +//! truncation of the stick-breaking representation of the DP. When \f$a_j = 1 +//! - d\f$ and \f$b_j = M + id \f$, it is the trunctation of a PY process. Its +//! state is composed of the weights \f$w_j\f$ in log-scale and the sticks +//! \f$v_j\f$. For more information about the class, please refer instead to +//! base classes, `AbstractMixing` and `BaseMixing`. + class TruncatedSBMixing : public BaseMixing { public: From c8ac99791a6d37b46a8e7e83994687ceff810815 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 16 May 2022 18:02:18 +0200 Subject: [PATCH 292/317] addressing bruno's comments --- src/hierarchies/base_hierarchy.h | 8 ++++---- src/hierarchies/lapnig_hierarchy.h | 4 ++-- src/hierarchies/lin_reg_uni_hierarchy.h | 4 ++-- src/hierarchies/nnig_hierarchy.h | 4 ++-- src/hierarchies/nnw_hierarchy.h | 8 ++++---- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index ad71a0c93..86ea08a0a 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -343,8 +343,8 @@ class BaseHierarchy : public AbstractHierarchy { virtual void initialize_state() = 0; //! Evaluates the log-marginal distribution of data in a single point - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container (prior or posterior) + //! hyperparameter values //! @param datum Point which is to be evaluated //! @return The evaluation of the lpdf virtual double marg_lpdf(ProtoHypersPtr hier_params, @@ -359,8 +359,8 @@ class BaseHierarchy : public AbstractHierarchy { } //! Evaluates the log-marginal distribution of data in a single point - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values //! @param datum Point which is to be evaluated //! @param covariate Covariate vector associated to datum //! @return The evaluation of the lpdf diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/lapnig_hierarchy.h index d4803a5c8..92b66fdd9 100644 --- a/src/hierarchies/lapnig_hierarchy.h +++ b/src/hierarchies/lapnig_hierarchy.h @@ -10,8 +10,8 @@ //! Laplace Normal-InverseGamma hierarchy for univariate data. //! This class represents a hierarchical model where data are distributed -//! according to a laplace likelihood (see the `LaplaceLikelihood` class for -//! deatils).The likelihood parameters have a Normal x InverseGamma centering +//! according to a Laplace likelihood (see the `LaplaceLikelihood` class for +//! deatils). The likelihood parameters have a Normal x InverseGamma centering //! distribution (see the `NxIGPriorModel` class for details). That is: //! \f[ //! f(x_i|\mu,\sigma^2) &= Laplace(\mu,\sqrt(\sigma^2/2))\\ diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index bd208ef01..7afd951f4 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -56,8 +56,8 @@ class LinRegUniHierarchy }; //! Evaluates the log-marginal distribution of data in a single point - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values //! @param datum Point which is to be evaluated //! @param covariate Covariate vectors associated to data //! @return The evaluation of the lpdf diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index cb6a0a545..161882ae1 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -48,8 +48,8 @@ class NNIGHierarchy }; //! Evaluates the log-marginal distribution of data in a single point - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values //! @param datum Point which is to be evaluated //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index 4217913d1..6dfbebb0e 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -56,8 +56,8 @@ class NNWHierarchy }; //! Evaluates the log-marginal distribution of data in a single point - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values //! @param datum Point which is to be evaluated //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, @@ -73,8 +73,8 @@ class NNWHierarchy //! Helper function that computes the predictive parameters for the //! multivariate t distribution from the current hyperparameter values. It is //! used to efficiently compute the log-marginal distribution of data. - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values //! @return A `HyperParam` object with the predictive parameters HyperParams get_predictive_t_parameters(ProtoHypersPtr hier_params) const { auto params = hier_params->nnw_state(); From b8903c8bfc116dc4211612688278d7b9024fea19 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 16 May 2022 18:03:47 +0200 Subject: [PATCH 293/317] rebuilding protos --- docs/CMakeLists.txt | 2 +- src/hierarchies/priors/fa_prior_model.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/CMakeLists.txt b/docs/CMakeLists.txt index 3318965a8..5981d2b7a 100644 --- a/docs/CMakeLists.txt +++ b/docs/CMakeLists.txt @@ -82,6 +82,6 @@ install(DIRECTORY ${SPHINX_BUILD} DESTINATION ${CMAKE_INSTALL_DOCDIR}) add_custom_target(document_bayesmix) -# add_dependencies(document_bayesmix document_protos) +add_dependencies(document_bayesmix document_protos) add_dependencies(document_bayesmix Doxygen) add_dependencies(document_bayesmix Sphinx) diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index d1689d74f..dfec2a731 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -15,7 +15,7 @@ //! \mu &\sim N_p(\tilde \mu, \psi I) \\ //! \Lambda &\sim DL(\alpha) \\ //! \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ -//! \sigma^2_j &\sim IG(a,b) \quad j=1,...,p +//! \sigma^2_j &\sim IG(a,b) \quadq j=1,...,p //! \f] //! Where DL is the Dirichlet-Laplace distribution. See Bhattacharya A., Pati //! D, Pillai N.S., Dunson D.B. (2015). JASA 110(512), 1479–1490 for details. From 469064eae943a853ea637a8e461eead6e54fd750 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 23:30:46 +0200 Subject: [PATCH 294/317] Fix docs --- .../likelihoods/states/base_state.h | 44 ++++++++++--------- src/hierarchies/likelihoods/states/fa_state.h | 33 ++++++++------ 2 files changed, 42 insertions(+), 35 deletions(-) diff --git a/src/hierarchies/likelihoods/states/base_state.h b/src/hierarchies/likelihoods/states/base_state.h index bc988493d..dbaae301b 100644 --- a/src/hierarchies/likelihoods/states/base_state.h +++ b/src/hierarchies/likelihoods/states/base_state.h @@ -9,26 +9,28 @@ namespace State { -//! Abstract base class for a generic state -//! -//! Given a statistical model with likelihood \f$ L(y|tau) \f$ and prior \f$ -//! p(\tau) \f$ a State class represents the value of tau at a certain MCMC -//! iteration. In addition, each instance stores the cardinality of the number -//! of observations in the model. -//! -//! State classes inheriting from this one should implement the methods -//! `set_from_proto()` and `to_proto()`, that are used to deserialize from -//! (and serialize to) a `bayesmix::AlgorithmState::ClusterState` -//! protocol buffer message. -//! -//! Optionally, each state can have an "unconstrained" representation, -//! where a bijective transformation B is applied to \f$ \tau \f$, so that -//! the image of B is \f$ R^d \f$ for some d. -//! This is essential for the default updaters such as `RandomWalkUpdater` -//! and `MalaUpdater` to work, but is not necessary for other model-specific -//! updaters. -//! If such a representation is needed, child classes should also implement -//! `get_unconstrained()`, `set_from_unconstrained()`, and `log_det_jac()`. +/** + * Abstract base class for a generic state + * + * Given a statistical model with likelihood \f$ L(y \mid \tau) \f$ and prior + * \f$ p(\tau) \f$ a State class represents the value of tau at a certain MCMC + * iteration. In addition, each instance stores the cardinality of the number + * of observations in the model. + * + * State classes inheriting from this one should implement the methods + * `set_from_proto()` and `to_proto()`, that are used to deserialize from + * (and serialize to) a `bayesmix::AlgorithmState::ClusterState` + * protocol buffer message. + * + * Optionally, each state can have an "unconstrained" representation, + * where a bijective transformation \f$ B \f$ is applied to \f$ \tau \f$, so + * that the image of \f$ B \f$ is in \f$ \mathbb{R}^d \f$ for some \f$ d \f$. + * This is essential for the default updaters such as `RandomWalkUpdater` + * and `MalaUpdater` to work, but is not necessary for other model-specific + * updaters. + * If such a representation is needed, child classes should also implement + * `get_unconstrained()`, `set_from_unconstrained()`, and `log_det_jac()`. + */ class BaseState { public: @@ -36,7 +38,7 @@ class BaseState { using ProtoState = bayesmix::AlgorithmState::ClusterState; - //! Returns the unconstrained representation \f$ x = B(tau) \f$ + //! Returns the unconstrained representation \f$ x = B(\tau) \f$ virtual Eigen::VectorXd get_unconstrained() { return Eigen::VectorXd(0); } //! Sets the current state as \f$ \tau = B^{-1}(in) \f$ diff --git a/src/hierarchies/likelihoods/states/fa_state.h b/src/hierarchies/likelihoods/states/fa_state.h index d9d8ef978..77e02c176 100644 --- a/src/hierarchies/likelihoods/states/fa_state.h +++ b/src/hierarchies/likelihoods/states/fa_state.h @@ -12,20 +12,25 @@ namespace State { -//! State of a Factor Analytic model -//! \f[ -//! Y_i = \Lambda \eta_i + \varepsilon -//! \f] -//! where \f$ Y_i \f$ is a `p`-dimensional vetor, \f$ \eta_i \f$ is a -//! d-dimensional one, \f$ \Lambda \f$ is a `p x d` matrix and \f$ \varepsilon -//! \f$ is an error term with mean zero and diagonal covariance matrix \f$ \psi -//! \f$. -//! -//! For faster likelihood evaluation, we store also the `cov_wood` factor and -//! the log determinant of the matrix \f$ \Lambda \Lambda^T + \psi \f$, see -//! the `compute_wood_chol_and_logdet(...)` function for more details. -//! -//! The unconstrained representation for this state is not implemented. +/** + * State of a Factor Analytic model + * + * \f[ + * Y_i = \Lambda\bm{\eta}_i + \bm{\varepsilon} + * \f] + * + * where \f$ Y_i \f$ is a \f$ p \f$-dimensional vetor, \f$ \bm{\eta}_i \f$ is a + * \f$ d \f$-dimensional one, \f$ \Lambda \f$ is a \f$ p \times d \f$ matrix + * and \f$ \bm{\varepsilon} \f$ is an error term with mean zero and diagonal + * covariance matrix \f$ \psi \f$. + * + * For faster likelihood evaluation, we store also the `cov_wood` factor and + * the log determinant of the matrix \f$ \Lambda \Lambda^T + \psi \f$, see + * the `compute_wood_chol_and_logdet(...)` function for more details. + * + * The unconstrained representation for this state is not implemented. + */ + class FA : public BaseState { public: Eigen::VectorXd mu, psi; From 8c906a8c3372b313fcf39763c186eed680836ad3 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 23:52:48 +0200 Subject: [PATCH 295/317] Fixing latex docs --- src/mixings/dirichlet_mixing.h | 29 ++++++++++------- src/mixings/logit_sb_mixing.h | 40 ++++++++++++----------- src/mixings/mixture_finite_mixing.h | 50 +++++++++++++++++------------ src/mixings/pityor_mixing.h | 31 ++++++++++-------- src/mixings/truncated_sb_mixing.h | 43 ++++++++++++++----------- 5 files changed, 110 insertions(+), 83 deletions(-) diff --git a/src/mixings/dirichlet_mixing.h b/src/mixings/dirichlet_mixing.h index fce3b4ba7..f1032624c 100644 --- a/src/mixings/dirichlet_mixing.h +++ b/src/mixings/dirichlet_mixing.h @@ -18,18 +18,23 @@ struct State { }; }; // namespace Dirichlet -//! Class that represents the EPPF induced by the Dirithclet process (DP) -//! introduced in Ferguson (1973), see also Sethuraman (1994). -//! The EPPF induced by the DP depends on a `totalmass` parameter M. -//! Given a clustering of n elements into k clusters, each with cardinality -//! \f$ n_j, j=1, ..., k \f$ the EPPF of the DP gives the following -//! probabilities for the cluster membership of the (n+1)-th observation: \f[ -//! p(\text{j-th cluster} | ...) &= n_j / (n + M) \\ -//! p(\text{new cluster} | ...) &= M / (n + M) -//! \f] -//! The state is solely composed of M, but we also store log(M) for efficiency -//! reasons. For more information about the class, please refer instead to base -//! classes, `AbstractMixing` and `BaseMixing`. +/** + * Class that represents the EPPF induced by the Dirithclet process (DP) + * introduced in Ferguson (1973), see also Sethuraman (1994). + * The EPPF induced by the DP depends on a `totalmass` parameter M. + * Given a clustering of n elements into k clusters, each with cardinality + * \f$ n_j, j=1, ..., k \f$ the EPPF of the DP gives the following + * probabilities for the cluster membership of the (n+1)-th observation: + * + * \f[ + * p(\text{j-th cluster} | ...) &= n_j / (n + M) \\ + * p(\text{new cluster} | ...) &= M / (n + M) + * \f] + * + * The state is solely composed of M, but we also store log(M) for efficiency + * reasons. For more information about the class, please refer instead to base + * classes, `AbstractMixing` and `BaseMixing`. + */ class DirichletMixing : public BaseMixing { diff --git a/src/mixings/logit_sb_mixing.h b/src/mixings/logit_sb_mixing.h index 51fd33fff..1061029c7 100644 --- a/src/mixings/logit_sb_mixing.h +++ b/src/mixings/logit_sb_mixing.h @@ -18,24 +18,28 @@ struct State { }; }; // namespace LogitSB -//! Class that represents the logit stick-breaking process indroduced in Rigon -//! and Durante (2020). -//! That is, a prior for weights (w_1,...,w_H), depending on covariates x in -//! R^p, in the H-1 dimensional unit simplex, defined as follows: -//! \f[ -//! w_1 &= v_1\\ -//! w_j &= v_j (1 - v_1) ... (1 - v_{j-1}), \quad \text{for } j=1, ... H-1 \\ -//! w_H &= 1 - (w_1 + w_2 + ... + w_{H-1}) \\ -//! v_j(x) &= 1 / exp(- <\alpha_j, x> ), for j = 1, ..., H-1 -//! \f] -//! -//! The main difference with the mentioned paper is that the authors propose a -//! Gibbs sampler in which the full conditionals are available in close form -//! thanks to a Polya-Gamma augmentation. Here instead, a Metropolis-adjusted -//! Langevin algorithm (MALA) step is used. The step-size of the MALA step must -//! be passed in the LogSBPrior Protobuf message. -//! For more information about the class, please refer instead to base classes, -//! `AbstractMixing` and `BaseMixing`. +/** + * Class that represents the logit stick-breaking process indroduced in Rigon + * and Durante (2020). + * That is, a prior for weights \f$ (w_1,\dots,w_H) \f$, depending on + * covariates \f$ x \f$ in \f$ \mathbb{R}^p \f$, in the H-1 dimensional unit + * simplex, defined as follows: + * + * \f[ + * w_1 &= v_1 \\ + * w_j &= v_j (1 - v_1) ... (1 - v_{j-1}), \quad \text{for } j=1, ... H-1 \\ + * w_H &= 1 - (w_1 + w_2 + ... + w_{H-1}) \\ + * v_j(x) &= 1 / exp(- <\alpha_j, x> ), for j = 1, ..., H-1 + * \f] + * + * The main difference with the mentioned paper is that the authors propose a + * Gibbs sampler in which the full conditionals are available in close form + * thanks to a Polya-Gamma augmentation. Here instead, a Metropolis-adjusted + * Langevin algorithm (MALA) step is used. The step-size of the MALA step must + * be passed in the LogSBPrior Protobuf message. + * For more information about the class, please refer instead to base classes, + * `AbstractMixing` and `BaseMixing`. + */ class LogitSBMixing : public BaseMixing { diff --git a/src/mixings/mixture_finite_mixing.h b/src/mixings/mixture_finite_mixing.h index 0a38752dd..67ce224dc 100644 --- a/src/mixings/mixture_finite_mixing.h +++ b/src/mixings/mixture_finite_mixing.h @@ -18,27 +18,35 @@ struct State { }; }; // namespace Mixture_Finite -//! Class that represents the Mixture of Finite Mixtures (MFM) [1] -//! The basic idea is to take usual finite mixture model with Dirichlet weights -//! and put a prior (Poisson) on the number of components. The EPPF induced by -//! MFM depends on a Dirichlet parameter 'gamma' and number \f$V_n(t)\f$, where -//! \f$V_n(t)\f$ depends on the Poisson rate parameter 'lambda'. -//! \f[ -//! V_n(t) = \sum_{k=1}^{\infty} ( k_(t)p_K(k) / (\gamma*k)^(n) ) -//! \f] -//! Given a clustering of n elements into k clusters, each with cardinality -//! \f$n_j, j=1, ..., k\f$, the EPPF of the MFM gives the following -//! probabilities for the cluster membership of the (n+1)-th observation: \f[ -//! p(\text{j-th cluster} | ...) &= (n_j + \gamma) / D \\ -//! p(\text{k+1-th cluster} | ...) &= V[k+1]/V[k] \gamma / D \\ -//! D &= n_j + \gamma / (n + \gamma * (k + V[k+1]/V[k])) -//! \f] -//! For numerical reasons each value of V is multiplied with a constant C -//! computed as the first term of the series of V_n[0]. -//! For more information about the class, please refer instead to base -//! classes, `AbstractMixing` and `BaseMixing`. -//! [1] "Mixture Models with a Prior on the Number of Components", J.W.Miller -//! and M.T.Harrison, 2015, arXiv:1502.06241v1 +/** + * Class that represents the Mixture of Finite Mixtures (MFM) [1] + * The basic idea is to take usual finite mixture model with Dirichlet weights + * and put a prior (Poisson) on the number of components. The EPPF induced by + * MFM depends on a Dirichlet parameter 'gamma' and number \f$V_n(t)\f$, where + * \f$ V_n(t) \f$ depends on the Poisson rate parameter 'lambda'. + * + * \f[ + * V_n(t) = \sum_{k=1}^{\infty} ( k_(t)p_K(k) / (\gamma*k)^(n) ) + * \f] + * + * Given a clustering of n elements into k clusters, each with cardinality + * \f$ n_j, j=1, ..., k \f$, the EPPF of the MFM gives the following + * probabilities for the cluster membership of the (n+1)-th observation: + * + * \f[ + * p(\text{j-th cluster} | ...) &= (n_j + \gamma) / D \\ + * p(\text{k+1-th cluster} | ...) &= V[k+1]/V[k] \gamma / D \\ + * D &= n_j + \gamma / (n + \gamma * (k + V[k+1]/V[k])) + * \f] + * + * For numerical reasons each value of V is multiplied with a constant C + * computed as the first term of the series of V_n[0]. + * For more information about the class, please refer instead to base + * classes, `AbstractMixing` and `BaseMixing`. + * + * [1] "Mixture Models with a Prior on the Number of Components", J.W.Miller + * and M.T.Harrison, 2015, arXiv:1502.06241v1 + */ class MixtureFiniteMixing : public BaseMixing { diff --git a/src/mixings/truncated_sb_mixing.h b/src/mixings/truncated_sb_mixing.h index 3af68249b..4d7d190fb 100644 --- a/src/mixings/truncated_sb_mixing.h +++ b/src/mixings/truncated_sb_mixing.h @@ -18,25 +18,30 @@ struct State { }; }; // namespace TruncSB -//! Class that represents a truncated stick-breaking process, as shown in -//! Ishwaran and James (2001). -//! -//! A truncated stick-breaking process is a prior for weights -//! \f$(w_1,...,w_H)\f$ in the H-1 dimensional unit simplex, and is defined as -//! follows: \f[ -//! w_1 &= v_1\\ -//! w_j &= v_j (1 - v_1) ... (1 - v_{j-1}), \quad \text{for } j=1, ... H-1 \\ -//! w_H &= 1 - (w_1 + w_2 + ... + w_{H-1}) -//! \f] -//! The \f$v_j\f$'s are called sticks and we assume them to be independently -//! distributed as \f$v_j \sim \text{Beta}(a_j, b_j)\f$. -//! -//! When \f$a_j = 1\f$ and \f$b_j = M\f$, the stick-breaking process is a -//! truncation of the stick-breaking representation of the DP. When \f$a_j = 1 -//! - d\f$ and \f$b_j = M + id \f$, it is the trunctation of a PY process. Its -//! state is composed of the weights \f$w_j\f$ in log-scale and the sticks -//! \f$v_j\f$. For more information about the class, please refer instead to -//! base classes, `AbstractMixing` and `BaseMixing`. +/** + * Class that represents a truncated stick-breaking process, as shown in + * Ishwaran and James (2001). + * + * A truncated stick-breaking process is a prior for weights + * \f$ (w_1,...,w_H) \f$ in the H-1 dimensional unit simplex, and is defined as + * follows: + * + * \f[ + * w_1 &= v_1 \\ + * w_j &= v_j (1 - v_1) ... (1 - v_{j-1}), \quad \text{for } j=1, ... H-1 \\ + * w_H &= 1 - (w_1 + w_2 + ... + w_{H-1}) + * \f] + * + * The \f$ v_j \f$'s are called sticks and we assume them to be independently + * distributed as \f$ v_j \sim \text{Beta}(a_j, b_j) \f$. + * + * When \f$ a_j = 1 \f$ and \f$ b_j = M \f$, the stick-breaking process is a + * truncation of the stick-breaking representation of the DP. + * When \f$ a_j = 1-d \f$ and \f$ b_j = M+id \f$, it is the trunctation of a PY + * process. Its state is composed of the weights \f$ w_j \f$ in log-scale and + * the sticks \f$ v_j \f$. For more information about the class, please refer + * instead to base classes, `AbstractMixing` and `BaseMixing`. + */ class TruncatedSBMixing : public BaseMixing { From a4ff9ab5c36082383ce79a5bcb3694582aa4531d Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 17 May 2022 00:00:45 +0200 Subject: [PATCH 296/317] Fix latex docs --- src/hierarchies/priors/fa_prior_model.h | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index dfec2a731..d1a0297b0 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -10,15 +10,20 @@ #include "hyperparams.h" #include "src/utils/rng.h" -//! A priormodel for the factor analyzers likelihood, that is -//! \f[ -//! \mu &\sim N_p(\tilde \mu, \psi I) \\ -//! \Lambda &\sim DL(\alpha) \\ -//! \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ -//! \sigma^2_j &\sim IG(a,b) \quadq j=1,...,p -//! \f] -//! Where DL is the Dirichlet-Laplace distribution. See Bhattacharya A., Pati -//! D, Pillai N.S., Dunson D.B. (2015). JASA 110(512), 1479–1490 for details. +/** + * A priormodel for the factor analyzers likelihood, that is + * + * \f[ + * \mu &\sim N_p(\tilde{\mu}, \psi I) \\ + * \Lambda &\sim DL(\alpha) \\ + * \Sigma &= \mathrm{diag}(\sigma^2_1, \ldots, \sigma^2_p) \\ + * \sigma^2_j &\sim IG(a,b) \quad j=1,...,p + * \f] + * + * Where \f$ DL \f$ is the Dirichlet-Laplace distribution. + * See Bhattacharya A., Pati D., Pillai N.S., Dunson D.B. (2015). + * JASA 110(512), 1479–1490 for details. + */ class FAPriorModel : public BasePriorModel Date: Tue, 17 May 2022 00:01:04 +0200 Subject: [PATCH 297/317] Fix text formatting --- docs/states.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/states.rst b/docs/states.rst index b20fddaac..68de1feb1 100644 --- a/docs/states.rst +++ b/docs/states.rst @@ -12,7 +12,7 @@ Moreover, they allow to go from the constrained to the unconstrained representat Code Structure -------------- -All classes must inherit from the `BaseState` class +All classes must inherit from the ``BaseState`` class .. doxygenclass:: State::BaseState :project: bayesmix From 1685c772aa5ebdbaeb878d4c008774b46304bd41 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Fri, 13 May 2022 10:56:52 +0200 Subject: [PATCH 298/317] Docs for hierarchies --- src/hierarchies/fa_hierarchy.h | 22 ++++++++++++++++ src/hierarchies/lapnig_hierarchy.h | 29 +++++++++++++-------- src/hierarchies/likelihoods/fa_likelihood.h | 4 +-- src/hierarchies/lin_reg_uni_hierarchy.h | 26 +++++++++++++++--- src/hierarchies/nnig_hierarchy.h | 25 +++++++++++++++--- src/hierarchies/nnw_hierarchy.h | 28 ++++++++++++++++++++ src/hierarchies/nnxig_hierarchy.h | 17 ++++++++++++ 7 files changed, 131 insertions(+), 20 deletions(-) diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index 15d278023..ac6c6ad16 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -8,18 +8,37 @@ #include "src/utils/distributions.h" #include "updaters/fa_updater.h" +//! Mixture of Factor Analysers hierarchy for multivariate data. +//! +//! This class represents a hierarchical model where data are distributed +//! according to a multivariate Normal likelihood with a specific factorization +//! of the covariance function (see the `FAHierarchy` class for details). The +//! likelihood parameters have a Dirichlet-Laplace distribution x InverseGamma +//! centering distribution (see the `FAPriorModel` class for details). That is: +//! f(x_i|mu,Sigma,Lambda) = N(mu,Sigma+Lambda*Lambda^T) +//! mu ~ N(mu0,psi*I) +//! Lambda ~ DL(alpha) +//! Sigma = diag(sig1^2,...,sigp^2) +//! sigj^2 ~ IG(a,b) for j=1,...,p +//! where Lambda is the latent score matrix (size p x d with d << p) and +//! DL(alpha) is the Laplace-Dirichlet distribution. +//! See Bhattacharya et al. (2015) for further details. + class FAHierarchy : public BaseHierarchy { public: FAHierarchy() = default; ~FAHierarchy() = default; + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::FA; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Initialize likelihood dimension to prior one like->set_dim(prior->get_dim()); @@ -37,6 +56,9 @@ class FAHierarchy } }; +//! Empirical-Bayes hyperparameters initialization for the FA HIerarchy. +//! Sets the hyperparameters in `hier` starting from the data on which the user +//! wants to fit the model. inline void set_fa_hyperparams_from_data(FAHierarchy* hier) { auto dataset_ptr = std::static_pointer_cast(hier->get_likelihood()) diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/lapnig_hierarchy.h index a7f049115..000b60bd5 100644 --- a/src/hierarchies/lapnig_hierarchy.h +++ b/src/hierarchies/lapnig_hierarchy.h @@ -1,22 +1,26 @@ #ifndef BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ #define BAYESMIX_HIERARCHIES_LAPNIG_HIERARCHY_H_ -// #include - -// #include -// #include -// #include - -// #include "algorithm_state.pb.h" -// #include "conjugate_hierarchy.h" -#include "hierarchy_id.pb.h" -// #include "hierarchy_prior.pb.h" - #include "base_hierarchy.h" +#include "hierarchy_id.pb.h" #include "likelihoods/laplace_likelihood.h" #include "priors/nxig_prior_model.h" #include "updaters/mala_updater.h" +//! Laplace Normal-InverseGamma hierarchy for univariate data. + +//! This class represents a hierarchical model where data are distributed +//! according to a laplace likelihood (see the `LaplaceLikelihood` class for +//! deatils).The likelihood parameters have a Normal x InverseGamma centering +//! distribution (see the `NxIGPriorModel` class for details). That is: +//! f(x_i|mu,lambda) = Laplace(mu,sqrt(var/2)) +//! mu ~ N(mu0,sig0^2) +//! var ~ IG(alpha0,beta0) +//! The state is composed of mean and variance (thus the scale for the Laplace +//! distribution is sqrt(var / 2)). The state hyperparameters are (mu_0, +//! sig0^2, alpha0, beta0), all scalar values. Note that this hierarchy is NOT +//! conjugate, thus the marginal distribution is not available in closed form. + class LapNIGHierarchy : public BaseHierarchy { @@ -24,12 +28,15 @@ class LapNIGHierarchy LapNIGHierarchy() = default; ~LapNIGHierarchy() = default; + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::LapNIG; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Get hypers auto hypers = prior->get_hypers(); diff --git a/src/hierarchies/likelihoods/fa_likelihood.h b/src/hierarchies/likelihoods/fa_likelihood.h index 0c5e4938d..8a40bb6f1 100644 --- a/src/hierarchies/likelihoods/fa_likelihood.h +++ b/src/hierarchies/likelihoods/fa_likelihood.h @@ -13,8 +13,8 @@ #include "states/includes.h" //! A gaussian factor analytic likelihood, that is -//! Y ~ N_p(mu, Lambda * Lambda^T + Psi) -//! Where Lambda is a `p x d` matrix, usually d << p and `Psi` is a diagonal +//! Y ~ N_p(mu, Sigma + Lambda * Lambda^T) +//! Where Lambda is a `p x d` matrix, usually d << p and `Sigma` is a diagonal //! matrix. //! //! Parameters are stored in a `State::FA` state. diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index ae449a103..a55a51251 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -7,21 +7,37 @@ #include "priors/mnig_prior_model.h" #include "updaters/mnig_updater.h" +//! Linear regression hierarchy for univariate data. +//! +//! This class implements a dependent hierarchy which represents the classical +//! univariate Bayesian linear regression model, i.e.: +//! y_i | \beta, x_i, \sigma^2 \sim N(\beta^T x_i, sigma^2) +//! \beta | \sigma^2 \sim N(\mu, sigma^2 Lambda^{-1}) +//! \sigma^2 \sim InvGamma(a, b) +//! +//! The state consists of the `regression_coeffs` \beta, and the `var` sigma^2. +//! Lambda is called the variance-scaling factor. Note that this hierarchy is +//! conjugate, thus the marginal distribution is available in closed form. For +//! more information, please refer to parent classes: `BaseHierarchy`, +//! `UniLinRegLikelihood` for deatails on the likelihood model, and +//! `MNIGPriorModel` for details on the prior model. + class LinRegUniHierarchy : public BaseHierarchy { public: + LinRegUniHierarchy() = default; ~LinRegUniHierarchy() = default; - using BaseHierarchy::BaseHierarchy; - + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::LinRegUni; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Initialize likelihood dimension to prior one like->set_dim(prior->get_dim()); @@ -34,6 +50,10 @@ class LinRegUniHierarchy like->set_state(state); }; + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const override { auto params = hier_params->lin_reg_uni_state(); diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index b70e0eeac..36d14c94e 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -7,20 +7,33 @@ #include "priors/nig_prior_model.h" #include "updaters/nnig_updater.h" +//! Conjugate Normal Normal-InverseGamma hierarchy for univariate data. +//! +//! This class represents a hierarchical model where data are distributed +//! according to a Normal likelihood (see the `UniNormLikelihood` class for +//! details). The likelihood parameters have a Normal-InverseGamma centering +//! distribution (see the `NIGPriorModel` class for details). That is: +//! f(x_i|mu,sig^2) = N(mu,sig^2) +//! (mu,sig^2) ~ N-IG(mu0, lambda0, alpha0, beta0) +//! The state is composed of mean and variance. The state hyperparameters are +//! (mu_0, lambda0, alpha0, beta0), all scalar values. Note that this hierarchy +//! is conjugate, thus the marginal distribution is available in closed form. + class NNIGHierarchy : public BaseHierarchy { public: + NNIGHierarchy() = default; ~NNIGHierarchy() = default; - using BaseHierarchy::BaseHierarchy; - + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNIG; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Get hypers auto hypers = prior->get_hypers(); @@ -31,6 +44,10 @@ class NNIGHierarchy like->set_state(state); }; + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { auto params = hier_params->nnig_state(); @@ -41,4 +58,4 @@ class NNIGHierarchy } }; -#endif +#endif // BAYESMIX_HIERARCHIES_NNIG_HIERARCHY_H_ diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index 35e5920af..ffe13b645 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -8,18 +8,36 @@ #include "src/utils/distributions.h" #include "updaters/nnw_updater.h" +//! Normal Normal-Wishart hierarchy for multivariate data. + +//! This class represents a hierarchy, i.e. a cluster, whose multivariate data +//! are distributed according to a multivariate normal likelihood (see the +//! `MultiNormLikelihood` for details). The likelihood parameters have a +//! Normal-Wishart centering distribution (see the `NWPriorModel` class for +//! details). That is: f(x_i|mu,tau) = N(mu,tau^{-1}) +//! (mu,tau) ~ NW(mu0, lambda0, tau0, nu0) +//! The state is composed of mean and precision matrix. The Cholesky factor and +//! log-determinant of the latter are also included in the container for +//! efficiency reasons. The state's hyperparameters are (mu0, lambda0, tau0, +//! nu0), which are respectively vector, scalar, matrix, and scalar. Note that +//! this hierarchy is conjugate, thus the marginal distribution is available in +//! closed form. + class NNWHierarchy : public BaseHierarchy { public: NNWHierarchy() = default; ~NNWHierarchy() = default; + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNW; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Initialize likelihood dimension to prior one like->set_dim(prior->get_dim()); @@ -34,6 +52,10 @@ class NNWHierarchy like->set_state(state); }; + //! Evaluates the log-marginal distribution of data in a single point + //! @param params Container of (prior or posterior) hyperparameter values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { HyperParams pred_params = get_predictive_t_parameters(hier_params); @@ -44,6 +66,12 @@ class NNWHierarchy logdet); } + //! Helper function that computes the predictive parameters for the + //! multivariate t distribution from the current hyperparameter values. It is + //! used to efficiently compute the log-marginal distribution of data. + //! @param hier_params Container of (prior or posterior) hyperparameter + //! values + //! @return A `HyperParam` object with the predictive parameters HyperParams get_predictive_t_parameters(ProtoHypersPtr hier_params) const { auto params = hier_params->nnw_state(); // Compute dof and scale of marginal distribution diff --git a/src/hierarchies/nnxig_hierarchy.h b/src/hierarchies/nnxig_hierarchy.h index a0de895fe..1901083c6 100644 --- a/src/hierarchies/nnxig_hierarchy.h +++ b/src/hierarchies/nnxig_hierarchy.h @@ -7,18 +7,35 @@ #include "priors/nxig_prior_model.h" #include "updaters/nnxig_updater.h" +//! Semi-conjugate Normal Normal x InverseGamma hierarchy for univariate data. +//! +//! This class represents a hierarchical model where data are distributed +//! according to a Normal likelihood (see the `UniNormLikelihood` class for +//! details). The likelihood parameters have a Normal x InverseGamma centering +//! distribution (see the `NxIGPriorModel` class for details). That is: +//! f(x_i|mu,sig^2) = N(mu,sig^2) +//! mu ~ N(mu0, sig0^2) +//! sig^2 ~ IG(alpha0, beta0) +//! The state is composed of mean and variance. The state hyperparameters are +//! (mu_0, sig0^2, alpha0, beta0), all scalar values. Note that this hierarchy +//! is NOT conjugate, meaning that the marginal distribution is not available +//! in closed form. + class NNxIGHierarchy : public BaseHierarchy { public: NNxIGHierarchy() = default; ~NNxIGHierarchy() = default; + //! Returns the Protobuf ID associated to this class bayesmix::HierarchyId get_id() const override { return bayesmix::HierarchyId::NNxIG; } + //! Sets the default updater algorithm for this hierarchy void set_default_updater() { updater = std::make_shared(); } + //! Initializes state parameters to appropriate values void initialize_state() override { // Get hypers auto hypers = prior->get_hypers(); From 6c7b378f44a62411ab0dd44905d8db9e0b173a2c Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 10:25:46 +0200 Subject: [PATCH 299/317] Improve docs --- src/hierarchies/fa_hierarchy.h | 2 +- src/hierarchies/lin_reg_uni_hierarchy.h | 41 +++++++++++++++---------- src/hierarchies/nnw_hierarchy.h | 9 +++--- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index ac6c6ad16..f98dbc7c5 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -12,7 +12,7 @@ //! //! This class represents a hierarchical model where data are distributed //! according to a multivariate Normal likelihood with a specific factorization -//! of the covariance function (see the `FAHierarchy` class for details). The +//! of the covariance matrix (see the `FAHierarchy` class for details). The //! likelihood parameters have a Dirichlet-Laplace distribution x InverseGamma //! centering distribution (see the `FAPriorModel` class for details). That is: //! f(x_i|mu,Sigma,Lambda) = N(mu,Sigma+Lambda*Lambda^T) diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index a55a51251..50d4239c2 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -7,20 +7,25 @@ #include "priors/mnig_prior_model.h" #include "updaters/mnig_updater.h" -//! Linear regression hierarchy for univariate data. -//! -//! This class implements a dependent hierarchy which represents the classical -//! univariate Bayesian linear regression model, i.e.: -//! y_i | \beta, x_i, \sigma^2 \sim N(\beta^T x_i, sigma^2) -//! \beta | \sigma^2 \sim N(\mu, sigma^2 Lambda^{-1}) -//! \sigma^2 \sim InvGamma(a, b) -//! -//! The state consists of the `regression_coeffs` \beta, and the `var` sigma^2. -//! Lambda is called the variance-scaling factor. Note that this hierarchy is -//! conjugate, thus the marginal distribution is available in closed form. For -//! more information, please refer to parent classes: `BaseHierarchy`, -//! `UniLinRegLikelihood` for deatails on the likelihood model, and -//! `MNIGPriorModel` for details on the prior model. +/** + * Linear regression hierarchy for univariate data. + * + * This class implements a dependent hierarchy which represents the classical + * univariate Bayesian linear regression model, i.e.: + * + * \f[ + * y_i \mid \beta, x_i, \sigma^2 &\sim N(\beta^T x_i, \sigma^2) \\ + * \beta \mid \sigma^2 &\sim N(\mu, \sigma^2 \Lambda^{-1}) \\ + * \sigma^2 &\sim InvGamma(a, b) + * \f] + * + * The state consists of the `regression_coeffs` \f$ \beta \f$, and the `var` + * \f$ \sigma^2 \f$. Lambda is called the variance-scaling factor. Note that + * this hierarchy is conjugate, thus the marginal distribution is available in + * closed form. For more information, please refer to the parent class + * `BaseHierarchy`, to the class `UniLinRegLikelihood` for details on the + * likelihood model and to `MNIGPriorModel` for details on the prior model. + */ class LinRegUniHierarchy : public BaseHierarchylin_reg_uni_state(); diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index ffe13b645..cb0744531 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -10,7 +10,7 @@ //! Normal Normal-Wishart hierarchy for multivariate data. -//! This class represents a hierarchy, i.e. a cluster, whose multivariate data +//! This class represents a hierarchy whose multivariate data //! are distributed according to a multivariate normal likelihood (see the //! `MultiNormLikelihood` for details). The likelihood parameters have a //! Normal-Wishart centering distribution (see the `NWPriorModel` class for @@ -53,9 +53,10 @@ class NNWHierarchy }; //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf + //! @param hier_params Container of (prior or posterior) hyperparameter + //! values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { HyperParams pred_params = get_predictive_t_parameters(hier_params); From da1a02c01686c71b932b55a726af0c4f9deddc90 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:31:56 +0200 Subject: [PATCH 300/317] Fix doxygen warnings --- src/algorithms/marginal_algorithm.h | 9 ++++++--- src/mixings/dirichlet_mixing.h | 1 + src/mixings/mixture_finite_mixing.h | 1 + src/mixings/pityor_mixing.h | 1 + src/runtime/factory.h | 2 +- 5 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/algorithms/marginal_algorithm.h b/src/algorithms/marginal_algorithm.h index 9b1aac89d..34af2167a 100644 --- a/src/algorithms/marginal_algorithm.h +++ b/src/algorithms/marginal_algorithm.h @@ -45,9 +45,12 @@ class MarginalAlgorithm : public BaseAlgorithm { protected: //! Computes marginal contribution of the given cluster to the lpdf estimate - //! @param hier Pointer to the `Hierarchy` object representing the cluster - //! @param grid Grid of row points on which the density is to be evaluated - //! @return The marginal component of the estimate + //! @param hier Pointer to the `Hierarchy` object representing the + //! cluster + //! @param grid Grid of row points on which the density is to be + //! evaluated + //! @param covariate (Optional) covariate vectors associated to data + //! @return The marginal component of the estimate virtual Eigen::VectorXd lpdf_marginal_component( const std::shared_ptr hier, const Eigen::MatrixXd &grid, diff --git a/src/mixings/dirichlet_mixing.h b/src/mixings/dirichlet_mixing.h index 27e9416c4..bc31c5cd2 100644 --- a/src/mixings/dirichlet_mixing.h +++ b/src/mixings/dirichlet_mixing.h @@ -60,6 +60,7 @@ class DirichletMixing protected: //! Returns probability mass for an old cluster (for marginal mixings only) //! @param n Total dataset size + //! @param n_clust Number of clusters //! @param log Whether to return logarithm-scale values or not //! @param propto Whether to include normalizing constants or not //! @param hier `Hierarchy` object representing the cluster diff --git a/src/mixings/mixture_finite_mixing.h b/src/mixings/mixture_finite_mixing.h index 3b85a4d10..c1adc3dfa 100644 --- a/src/mixings/mixture_finite_mixing.h +++ b/src/mixings/mixture_finite_mixing.h @@ -70,6 +70,7 @@ class MixtureFiniteMixing protected: //! Returns probability mass for an old cluster (for marginal mixings only) //! @param n Total dataset size + //! @param n_clust Number of clusters //! @param log Whether to return logarithm-scale values or not //! @param propto Whether to include normalizing constants or not //! @param hier `Hierarchy` object representing the cluster diff --git a/src/mixings/pityor_mixing.h b/src/mixings/pityor_mixing.h index 67d9218ac..5ebec6958 100644 --- a/src/mixings/pityor_mixing.h +++ b/src/mixings/pityor_mixing.h @@ -62,6 +62,7 @@ class PitYorMixing protected: //! Returns probability mass for an old cluster (for marginal mixings only) //! @param n Total dataset size + //! @param n_clust Number of clusters //! @param log Whether to return logarithm-scale values or not //! @param propto Whether to include normalizing constants or not //! @param hier `Hierarchy` object representing the cluster diff --git a/src/runtime/factory.h b/src/runtime/factory.h index f496bd42f..8a4bbba31 100644 --- a/src/runtime/factory.h +++ b/src/runtime/factory.h @@ -77,7 +77,7 @@ class Factory { //! Adds a builder function to the storage //! @param id Identifier to associate the builder with - //! @param bulider Builder function for a specific object type + //! @param builder Builder function for a specific object type void add_builder(const Identifier &id, const Builder &builder) { storage.insert(std::make_pair(id, builder)); } From df6e90de69c92dfd9ec51cfb2ab253584bf1e6e4 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:32:55 +0200 Subject: [PATCH 301/317] Fix utils.rst bad rendering --- src/utils/cluster_utils.h | 7 ++- src/utils/distributions.h | 124 ++++++++++++++++++++------------------ src/utils/eigen_utils.h | 7 ++- src/utils/io_utils.h | 5 +- src/utils/proto_utils.h | 11 ++-- src/utils/rng.h | 3 +- 6 files changed, 84 insertions(+), 73 deletions(-) diff --git a/src/utils/cluster_utils.h b/src/utils/cluster_utils.h index 0fe166399..1930bab16 100644 --- a/src/utils/cluster_utils.h +++ b/src/utils/cluster_utils.h @@ -3,15 +3,18 @@ #include -//! This file includes some utilities for cluster estimation. These functions -//! only use Eigen ojects. +//! \file cluster_utils.h +//! The `cluster_utils.h` file includes some utilities for cluster estimation. +//! These functions only use Eigen objects. namespace bayesmix { + //! Computes the posterior similarity matrix the data Eigen::MatrixXd posterior_similarity(const Eigen::MatrixXd &alloc_chain); //! Estimates the clustering structure of the data via LS minimization Eigen::VectorXi cluster_estimate(const Eigen::MatrixXi &alloc_chain); + } // namespace bayesmix #endif // BAYESMIX_UTILS_CLUSTER_UTILS_H_ diff --git a/src/utils/distributions.h b/src/utils/distributions.h index 7fcd2bbe0..d6b7a0635 100644 --- a/src/utils/distributions.h +++ b/src/utils/distributions.h @@ -7,14 +7,15 @@ #include "algorithm_state.pb.h" -//! This file includes several useful functions related to probability -//! distributions, including categorical variables, popular multivariate -//! distributions, and distribution distances. Some of these functions make use -//! of OpenMP parallelism to achieve better efficiency. +//! @file distributions.h +//! The `distributions.h` file includes several useful functions related to +//! probability distributions, including categorical variables, popular +//! multivariate distributions, and distribution distances. Some of these +//! functions make use of OpenMP parallelism to achieve better efficiency. namespace bayesmix { -/* +/** * Returns a pseudorandom categorical random variable on the set * {start, ..., start + k} where k is the size of the given probability vector * @@ -26,64 +27,66 @@ namespace bayesmix { int categorical_rng(const Eigen::VectorXd &probas, std::mt19937_64 &rng, const int start = 0); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution parametrized by mean and precision matrix on a single point * - * @param datum Point in which to evaluate the the lpdf - * @param mean The mean of the Gaussian distribution - * @prec_chol The (lower) Cholesky factor of the precision matrix - * @prec_logdet The logarithm of the determinant of the precision matrix - * @return The evaluation of the lpdf + * @param datum Point in which to evaluate the the lpdf + * @param mean The mean of the Gaussian distribution + * @param prec_chol The (lower) Cholesky factor of the precision matrix + * @param prec_logdet The logarithm of the determinant of the precision + * matrix + * @return The evaluation of the lpdf */ double multi_normal_prec_lpdf(const Eigen::VectorXd &datum, const Eigen::VectorXd &mean, const Eigen::MatrixXd &prec_chol, const double prec_logdet); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution parametrized by mean and precision matrix on multiple points * - * @param data Grid of points (by row) on which to evaluate the lpdf - * @param mean The mean of the Gaussian distribution - * @prec_chol The (lower) Cholesky factor of the precision matrix - * @prec_logdet The logarithm of the determinant of the precision matrix - * @return The evaluation of the lpdf + * @param data Grid of points (by row) on which to evaluate the lpdf + * @param mean The mean of the Gaussian distribution + * @param prec_chol The (lower) Cholesky factor of the precision matrix + * @param prec_logdet The logarithm of the determinant of the precision + * matrix + * @return The evaluation of the lpdf */ Eigen::VectorXd multi_normal_prec_lpdf_grid(const Eigen::MatrixXd &data, const Eigen::VectorXd &mean, const Eigen::MatrixXd &prec_chol, const double prec_logdet); -/* +/** * Returns a pseudorandom multivariate normal random variable with diagonal * covariance matrix * - * @param mean The mean of the Gaussian r.v. - * @param cov_diag The diagonal covariance matrix - * @rng Random number generator - * @return multivariate normal r.v. + * @param mean The mean of the Gaussian r.v. + * @param cov_diag The diagonal covariance matrix + * @param rng Random number generator + * @return Multivariate normal r.v. */ Eigen::VectorXd multi_normal_diag_rng( const Eigen::VectorXd &mean, const Eigen::DiagonalMatrix &cov_diag, std::mt19937_64 &rng); -/* +/** * Returns a pseudorandom multivariate normal random variable parametrized * through mean and Cholesky decomposition of precision matrix * - * @param mean The mean of the Gaussian r.v. - * @prec_chol The (lower) Cholesky factor of the precision matrix - * @param rng Random number generator - * @return multivariate normal r.v. + * @param mean The mean of the Gaussian r.v. + * @param prec_chol The (lower) Cholesky factor of the precision matrix + * @param rng Random number generator + * @return Multivariate normal r.v. */ Eigen::VectorXd multi_normal_prec_chol_rng( const Eigen::VectorXd &mean, const Eigen::LLT &prec_chol, std::mt19937_64 &rng); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution with the following covariance structure: * Sigma + Lambda * Lambda^T @@ -91,7 +94,6 @@ Eigen::VectorXd multi_normal_prec_chol_rng( * y^T*(Sigma + Lambda * Lambda^T)^{-1}*y = y^T*Sigma^{-1}*y - * ||wood_factor*y||^2 * - * * @param datum Point on which to evaluate the lpdf * @param mean The mean of the Gaussian distribution * @param sigma_diag_inverse The inverse of the diagonal of Sigma matrix @@ -106,7 +108,7 @@ double multi_normal_lpdf_woodbury_chol( const Eigen::DiagonalMatrix &sigma_diag_inverse, const Eigen::MatrixXd &wood_factor, const double &cov_logdet); -/* +/** * Evaluates the log probability density function of a multivariate Gaussian * distribution with the following covariance structure: * Sigma + Lambda * Lambda^T @@ -116,41 +118,42 @@ double multi_normal_lpdf_woodbury_chol( * computation from being O(p^3) to being O(d^3 p) which gives a substantial * speedup when p >> d * - * @param datum Point on which to evaluate the lpdf - * @param mean The mean of the Gaussian distribution + * @param datum Point on which to evaluate the lpdf + * @param mean The mean of the Gaussian distribution * @param sigma_diag The diagonal of Sigma matrix * @param lambda Rectangular matrix in Woodbury Identity - * @return The evaluation of the lpdf + * @return The evaluation of the lpdf */ double multi_normal_lpdf_woodbury(const Eigen::VectorXd &datum, const Eigen::VectorXd &mean, const Eigen::VectorXd &sigma_diag, const Eigen::MatrixXd &lambda); -/* - * Returns the log-determinant of the matrix Lambda Lambda^T + Sigma +/** + * Returns the log-determinant of the matrix \f$ \Lambda\Lambda^T + \Sigma \f$ * and the 'wood_factor', i.e. * L^{-1} * Lambda^T * Sigma^{-1}, * where L is the (lower) Cholesky factor of * I + Lambda^T * Sigma^{-1} * Lambda * - * @param sigma_dag_inverse The inverse of the diagonal matrix Sigma + * @param sigma_diag_inverse The inverse of the diagonal matrix Sigma * @param lambda The matrix Lambda */ std::pair compute_wood_chol_and_logdet( const Eigen::DiagonalMatrix &sigma_diag_inverse, const Eigen::MatrixXd &lambda); -/* +/** * Evaluates the log probability density function of a multivariate Student's t * distribution on a single point * - * @param datum Point in which to evaluate the the lpdf - * @param df The degrees of freedom of the Student's t distribution - * @param mean The mean of the Student's t distribution - * @invscale_chol The (lower) Cholesky factor of the inverse scale matrix - * @prec_logdet The logarithm of the determinant of the inverse scale matrix - * @return The evaluation of the lpdf + * @param datum Point in which to evaluate the the lpdf + * @param df The degrees of freedom of the Student's t distribution + * @param mean The mean of the Student's t distribution + * @param invscale_chol The (lower) Cholesky factor of the inverse scale matrix + * @param scale_logdet The logarithm of the determinant of the inverse scale + * matrix + * @return The evaluation of the lpdf */ double multi_student_t_invscale_lpdf(const Eigen::VectorXd &datum, const double df, @@ -158,28 +161,29 @@ double multi_student_t_invscale_lpdf(const Eigen::VectorXd &datum, const Eigen::MatrixXd &invscale_chol, const double scale_logdet); -/* +/** * Evaluates the log probability density function of a multivariate Student's t * distribution on multiple points * - * @param data Grid of points (by row) on which to evaluate the lpdf - * @param df The degrees of freedom of the Student's t distribution - * @param mean The mean of the Student's t distribution - * @invscale_chol The (lower) Cholesky factor of the inverse scale matrix - * @prec_logdet The logarithm of the determinant of the inverse scale matrix - * @return The evaluation of the lpdf + * @param data Grid of points (by row) on which to evaluate the lpdf + * @param df The degrees of freedom of the Student's t distribution + * @param mean The mean of the Student's t distribution + * @param invscale_chol The (lower) Cholesky factor of the inverse scale matrix + * @param scale_logdet The logarithm of the determinant of the inverse scale + * matrix + * @return The evaluation of the lpdf */ Eigen::VectorXd multi_student_t_invscale_lpdf_grid( const Eigen::MatrixXd &data, const double df, const Eigen::VectorXd &mean, const Eigen::MatrixXd &invscale_chol, const double scale_logdet); -/* +/** * Computes the L^2 distance between the univariate mixture of Gaussian - * densities p1(x) = \sum_{h=1}^m1 w1[h] N(x | mean1[h], var1[h]) and - * p2(x) = \sum_{h=1}^m2 w2[h] N(x | mean2[h], var2[h]) + * densities p1(x) = sum_{h=1}^m1 w1[h] N(x | mean1[h], var1[h]) and + * p2(x) = sum_{h=1}^m2 w2[h] N(x | mean2[h], var2[h]) * * The L^2 distance amounts to - * d(p, q) = (\int (p(x) - q(x)^2 dx))^{1/2} + * d(p, q) = (int (p(x) - q(x)^2 dx))^{1/2} */ double gaussian_mixture_dist(const Eigen::VectorXd &means1, const Eigen::VectorXd &vars1, @@ -188,13 +192,13 @@ double gaussian_mixture_dist(const Eigen::VectorXd &means1, const Eigen::VectorXd &vars2, const Eigen::VectorXd &weights2); -/* +/** * Computes the L^2 distance between the multivariate mixture of Gaussian - * densities p1(x) = \sum_{h=1}^m1 w1[h] N(x | mean1[h], Prec[1]^{-1}) and - * p2(x) = \sum_{h=1}^m2 w2[h] N(x | mean2[h], Prec2[h]^{-1}) + * densities p1(x) = sum_{h=1}^m1 w1[h] N(x | mean1[h], Prec[1]^{-1}) and + * p2(x) = sum_{h=1}^m2 w2[h] N(x | mean2[h], Prec2[h]^{-1}) * * The L^2 distance amounts to - * d(p, q) = (\int (p(x) - q(x)^2 dx))^{1/2} + * d(p, q) = (int (p(x) - q(x)^2 dx))^{1/2} */ double gaussian_mixture_dist(const std::vector &means1, const std::vector &precs1, @@ -203,11 +207,11 @@ double gaussian_mixture_dist(const std::vector &means1, const std::vector &precs2, const Eigen::VectorXd &weights2); -/* +/** * Computes the L^2 distance between the mixture of Gaussian * densities p(x) and q(x). These could be either univariate or multivariate. * The L2 distance amounts to - * d(p, q) = (\int (p(x) - q(x)^2 dx))^{1/2} + * d(p, q) = (int (p(x) - q(x)^2 dx))^{1/2} * * @param clus1, clus2 Cluster-specific parameters of the mix. densities * @param weights1, weights2 Weigths of the mixture densities diff --git a/src/utils/eigen_utils.h b/src/utils/eigen_utils.h index 463cb3979..2307b7f6f 100644 --- a/src/utils/eigen_utils.h +++ b/src/utils/eigen_utils.h @@ -4,9 +4,10 @@ #include #include -//! This file implements a few methods to manipulate groups of matrices, mainly -//! by joining different objects, as well as additional utilities for SPD -//! checking and grid creation. +//! @file eigen_utils.h +//! The `eigen_utils.h` file implements a few methods to manipulate groups of +//! matrices, mainly by joining different objects, as well as additional +//! utilities for SPD checking and grid creation. namespace bayesmix { //! Concatenates a vector of Eigen matrices along the rows diff --git a/src/utils/io_utils.h b/src/utils/io_utils.h index 89b830e34..b9c4231a6 100644 --- a/src/utils/io_utils.h +++ b/src/utils/io_utils.h @@ -3,8 +3,9 @@ #include -//! This file implements basic input-output utilities for Eigen matrices from -//! and to text files. +//! @file io_utils.h +//! The `io_utils.h` file implements basic input-output utilities for Eigen +//! matrices from and to text files. namespace bayesmix { //! Checks whether the given file is available for writing diff --git a/src/utils/proto_utils.h b/src/utils/proto_utils.h index cd5466d0a..cb8c3333d 100644 --- a/src/utils/proto_utils.h +++ b/src/utils/proto_utils.h @@ -5,11 +5,12 @@ #include "matrix.pb.h" -//! This file implements a few useful functions to manipulate Protobuf objects. -//! For instance, this library implements its own version of vectors and -//! matrices, and the functions implemented here convert from these types to -//! the Eigen ones and viceversa. One can also read a Protobuf from a text -//! file. This is mostly useful for algorithm configuration files. +//! @file proto_utils.h +//! The `proto_utils.h` file implements a few useful functions to manipulate +//! Protobuf objects. For instance, this library implements its own version of +//! vectors and matrices, and the functions implemented here convert from these +//! types to the Eigen ones and viceversa. One can also read a Protobuf from a +//! text file. This is mostly useful for algorithm configuration files. namespace bayesmix { diff --git a/src/utils/rng.h b/src/utils/rng.h index 8fc6e6264..fc71a2f69 100644 --- a/src/utils/rng.h +++ b/src/utils/rng.h @@ -3,7 +3,8 @@ #include -//! Simple Random Number Generation class wrapper. +//! @file rng.h +//! The `rng.h` file defines a simple Random Number Generation class wrapper. //! This class wraps the C++ standard RNG object and allows the use of any RNG //! seed. It is implemented as a singleton, so that every object used in the //! library has access to the same exact RNG engine. From 1ec91e2e8d6ace9edfd59dc3670f25d4475260b8 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:33:59 +0200 Subject: [PATCH 302/317] Uniform docs for likelihoods (+ latex in doxygen) --- src/hierarchies/likelihoods/fa_likelihood.h | 25 +++++++++------- .../likelihoods/laplace_likelihood.h | 26 +++++++++------- .../likelihoods/multi_norm_likelihood.h | 20 ++++++++----- .../likelihoods/uni_lin_reg_likelihood.h | 30 ++++++++++--------- .../likelihoods/uni_norm_likelihood.h | 20 ++++++++----- 5 files changed, 70 insertions(+), 51 deletions(-) diff --git a/src/hierarchies/likelihoods/fa_likelihood.h b/src/hierarchies/likelihoods/fa_likelihood.h index 8a40bb6f1..3e2e08e40 100644 --- a/src/hierarchies/likelihoods/fa_likelihood.h +++ b/src/hierarchies/likelihoods/fa_likelihood.h @@ -12,16 +12,21 @@ #include "base_likelihood.h" #include "states/includes.h" -//! A gaussian factor analytic likelihood, that is -//! Y ~ N_p(mu, Sigma + Lambda * Lambda^T) -//! Where Lambda is a `p x d` matrix, usually d << p and `Sigma` is a diagonal -//! matrix. -//! -//! Parameters are stored in a `State::FA` state. -//! We store as summary statistics the sum of the y_i's, but it is -//! not sufficient for all the updates involved. Therefore, all the -//! observations allocated to a cluster are processed when computing the -//! cluster lpdf. +/** + * A gaussian factor analytic likelihood, using the `State::FA` state. + * Represents the model: + * + * \f[ + * \bm{y}_1,\dots,\bm{y}_k \stackrel{\small\mathrm{iid}}{\sim} N_p(\bm{\mu}, + * \Sigma + \Lambda\Lambda^T), \f] + * + * where Lambda is a \f$ p \times d \f$ matrix, usually \f$ d << p \f$ and \f$ + * \Sigma \f$ is a diagonal matrix. Parameters are stored in a `State::FA` + * state. We store as summary statistics the sum of the \f$ \bm{y}_i \f$'s, but + * it is not sufficient for all the updates involved. Therefore, all the + * observations allocated to a cluster are processed when computing the + * cluster lpdf. + */ class FALikelihood : public BaseLikelihood { public: diff --git a/src/hierarchies/likelihoods/laplace_likelihood.h b/src/hierarchies/likelihoods/laplace_likelihood.h index 6b9196b8a..9d4c25128 100644 --- a/src/hierarchies/likelihoods/laplace_likelihood.h +++ b/src/hierarchies/likelihoods/laplace_likelihood.h @@ -11,17 +11,21 @@ #include "base_likelihood.h" #include "states/includes.h" -//! A univariate laplace likelihood -//! -//! Represents the model: -//! y_i ~ Laplace(mu, var) -//! where mu is the mean and center of the distribution -//! and var is the variance. The scale is then sqrt(var / 2) -//! These parameters are stored in a `State::UniLS` state -//! -//! Since the Laplace likelihood does not have sufficient statistics -//! other than the whole sample, the `update_sum_stats` method -//! does nothing. +/** + * A univariate Laplace likelihood, using the `State::UniLS` state. Represents + * the model: + * + * \f[ + * y_1,\dots,y_k \mid \mu, \sigma^2 \stackrel{\small\mathrm{iid}}{\sim} + * Laplace(\mu,\sigma^2), \f] + * + * where \f$ \mu \f$ is the mean and center of the distribution + * and \f$ \sigma^2 \f$ is the variance. The scale parameter \f$ \lambda \f$ is + * then \f$ \sqrt{\sigma^2/2} \f$. These parameters are stored in a + * `State::UniLS` state. Since the Laplace likelihood does not have sufficient + * statistics other than the whole sample, the `update_sum_stats()` method does + * nothing. + */ class LaplaceLikelihood : public BaseLikelihood { diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.h b/src/hierarchies/likelihoods/multi_norm_likelihood.h index 13f338faa..6e249a75e 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.h +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.h @@ -12,14 +12,18 @@ #include "base_likelihood.h" #include "states/includes.h" -//! A multivariate normal likelihood -//! -//! Represents the model: -//! y_1, ..., y_m ~ N(mu, Cov) -//! where (mu, Cov) are stored in a `State::MultiLS` state -//! -//! The sufficient statistics stored are the sum of the y_i's -//! and the sum of y_i^T y_i. +/** + * A multivariate normal likelihood, using the `State::MultiLS` state. + * Represents the model: + * + * \f[ + * \bm{y}_1,\dots, \bm{y}_k \stackrel{\small\mathrm{iid}}{\sim} + * N_p(\bm{\mu}, \Sigma), \f] + * + * where \f$ (\bm{\mu}, \Sigma) \f$ are stored in a `State::MultiLS` state. + * The sufficient statistics stored are the sum of the \f$ \bm{y}_i \f$'s + * and the sum of \f$ \bm{y}_i^T \bm{y}_i \f$. + */ class MultiNormLikelihood : public BaseLikelihood { diff --git a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h index 96600d4dc..bb9e55687 100644 --- a/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h +++ b/src/hierarchies/likelihoods/uni_lin_reg_likelihood.h @@ -11,16 +11,18 @@ #include "base_likelihood.h" #include "states/includes.h" -//! A scalar linear regression model -//! -//! Represents the model: -//! y_i ~ N(x_i^T * reg_coeffs, var) -//! where (reg_coeffs, var) are stored in a `State::UniLinRegLS` state -//! -//! The sufficient statistics stored are the -//! 1) sum of y_i^2 -//! 2) sum of x_i^T x_i -//! 3) sum of y_i x_i^T +/** + * A scalar linear regression model, using the `State::UniLinRegLS` state. + * Represents the model: + * + * \f[ + * y_i \mid \bm{x}_i, \bm{\beta}, \sigma^2 + * \stackrel{\small\mathrm{ind}}{\sim} N(\bm{x}_i^T\bm{\beta},\sigma^2), \f] + * + * where \f$ (\bm{\beta}, \sigma^2) \f$ are stored in a `State::UniLinRegLS` + * state. The sufficient statistics stored are the sum of \f$ y_i^2 \f$, the + * sum of \f$ \bm{x}_i^T \bm{x}_i \f$ and the sum of \f$ y_i \bm{x}_i^T \f$. + */ class UniLinRegLikelihood : public BaseLikelihood { @@ -48,13 +50,13 @@ class UniLinRegLikelihood const Eigen::RowVectorXd &covariate, bool add) override; - //! Dimension of the coefficients vector + // Dimension of the coefficients vector unsigned int dim; - //! Represents pieces of y^t y + // Represents pieces of y^t y double data_sum_squares; - //! Represents pieces of X^T X + // Represents pieces of X^T X Eigen::MatrixXd covar_sum_squares; - //! Represents pieces of X^t y + // Represents pieces of X^t y Eigen::VectorXd mixed_prod; }; diff --git a/src/hierarchies/likelihoods/uni_norm_likelihood.h b/src/hierarchies/likelihoods/uni_norm_likelihood.h index eb17f6853..e278a3635 100644 --- a/src/hierarchies/likelihoods/uni_norm_likelihood.h +++ b/src/hierarchies/likelihoods/uni_norm_likelihood.h @@ -11,14 +11,18 @@ #include "base_likelihood.h" #include "states/includes.h" -//! A univariate normal likelihood, using the `State::UniLS` state. -//! -//! Represents the model: -//! y_1, ..., y_m ~ N(mu, var) -//! where (mu, var) are stored in a `State::UniLS` state -//! -//! The sufficient statistics stored are the sum of the y_i's -//! and the sum of y_i^2. +/** + * A univariate normal likelihood, using the `State::UniLS` state. Represents + * the model: + * + * \f[ + * y_1, \dots, y_k \mid \mu, \sigma^2 \stackrel{\small\mathrm{iid}}{\sim} + * N(\mu, \sigma^2), \f] + * + * where \f$ (\mu, \sigma^2) \f$ are stored in a `State::UniLS` state. + * The sufficient statistics stored are the sum of the \f$ y_i \f$'s and the + * sum of \f$ y_i^2 \f$. + */ class UniNormLikelihood : public BaseLikelihood { From 2cd11f8f2a5b3cc20c6cc1f6214698cb9a4556b4 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:34:44 +0200 Subject: [PATCH 303/317] Fix doxygen warnings --- src/hierarchies/base_hierarchy.h | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index 3206c4967..ad71a0c93 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -343,9 +343,10 @@ class BaseHierarchy : public AbstractHierarchy { virtual void initialize_state() = 0; //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf + //! @param hier_params Container of (prior or posterior) hyperparameter + //! values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf virtual double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const { if (!is_conjugate()) { @@ -358,10 +359,11 @@ class BaseHierarchy : public AbstractHierarchy { } //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @param covariate Covariate vector associated to datum - //! @return The evaluation of the lpdf + //! @param hier_params Container of (prior or posterior) hyperparameter + //! values + //! @param datum Point which is to be evaluated + //! @param covariate Covariate vector associated to datum + //! @return The evaluation of the lpdf virtual double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum, const Eigen::RowVectorXd &covariate) const { From 46dfb10065aeba88bd056c039e3089a996aade0c Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:35:28 +0200 Subject: [PATCH 304/317] Add latex to doxygen (ONGOING) --- src/hierarchies/abstract_hierarchy.h | 74 ++++++++++++++----------- src/hierarchies/lin_reg_uni_hierarchy.h | 8 +-- src/hierarchies/nnig_hierarchy.h | 7 ++- 3 files changed, 50 insertions(+), 39 deletions(-) diff --git a/src/hierarchies/abstract_hierarchy.h b/src/hierarchies/abstract_hierarchy.h index 5bd4387e0..947f362cf 100644 --- a/src/hierarchies/abstract_hierarchy.h +++ b/src/hierarchies/abstract_hierarchy.h @@ -16,38 +16,48 @@ #include "src/hierarchies/updaters/abstract_updater.h" #include "src/utils/rng.h" -//! Abstract base class for a hierarchy object. -//! This class is the basis for a curiously recurring template pattern (CRTP) -//! for `Hierarchy` objects, and is solely composed of interface functions for -//! derived classes to use. For more information about this pattern, as well -//! the list of methods required for classes in this inheritance tree, please -//! refer to the README.md file included in this folder. - -//! This abstract class represents a Bayesian hierarchical model: -//! x_1, ..., x_n \sim f(x | \theta) -//! theta \sim G -//! A Hierarchy object can compute the following quantities: -//! 1- the likelihood log-probability density function -//! 2- the prior predictive probability: \int_\Theta f(x | theta) G(d\theta) -//! (for conjugate models only) -//! 3- the posterior predictive probability -//! \int_\Theta f(x | theta) G(d\theta | x_1, ..., x_n) -//! (for conjugate models only) -//! Moreover, the Hierarchy knows how to sample from the full conditional of -//! theta, possibly in an approximate way. -//! -//! In the context of our Gibbs samplers, an hierarchy represents the parameter -//! value associated to a certain cluster, and also knows which observations -//! are allocated to that cluster. -//! Moreover, hyperparameters and (possibly) hyperpriors associated to them can -//! be shared across multiple Hierarchies objects via a shared pointer. -//! In conjunction with a single `Mixing` object, a collection of `Hierarchy` -//! objects completely defines a mixture model, and these two parts can be -//! chosen independently of each other. -//! Communication with other classes, as well as storage of some relevant -//! values, is performed via appropriately defined Protobuf messages (see for -//! instance the proto/ls_state.proto and proto/hierarchy_prior.proto files) -//! and their relative class methods. +/** + * Abstract base class for a hierarchy object. + * This class is the basis for a curiously recurring template pattern (CRTP) + * for `Hierarchy` objects, and is solely composed of interface functions for + * derived classes to use. For more information about this pattern, as well + * the list of methods required for classes in this inheritance tree, please + * refer to the README.md file included in this folder. + * + * This abstract class represents a Bayesian hierarchical model: + * + * \f[ + * x_1,\dots,x_n &\sim f(x \mid \theta) \\ + * \theta &\sim G + * \f] + * + * A Hierarchy object can compute the following quantities: + * + * 1. the likelihood log-probability density function + * 2. the prior predictive probability: \f$ \int_\Theta f(x \mid \theta) + * G(d\theta) \f$ (for conjugate models only) + * 3. the posterior predictive probability + * \f$ \int_\Theta f(x \mid \theta) G(d\theta \mid x_1, ..., x_n) \f$ + * (for conjugate models only) + * + * Moreover, the Hierarchy knows how to sample from the full conditional of + * \f$ \theta \f$, possibly in an approximate way. + * + * In the context of our Gibbs samplers, an hierarchy represents the parameter + * value associated to a certain cluster, and also knows which observations + * are allocated to that cluster. + * + * Moreover, hyperparameters and (possibly) hyperpriors associated to them can + * be shared across multiple Hierarchies objects via a shared pointer. + * In conjunction with a single `Mixing` object, a collection of `Hierarchy` + * objects completely defines a mixture model, and these two parts can be + * chosen independently of each other. + * + * Communication with other classes, as well as storage of some relevant + * values, is performed via appropriately defined Protobuf messages (see for + * instance the `proto/ls_state.proto` and `proto/hierarchy_prior.proto` files) + * and their relative class methods. + */ class AbstractHierarchy { public: diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index 50d4239c2..bd208ef01 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -20,10 +20,10 @@ * \f] * * The state consists of the `regression_coeffs` \f$ \beta \f$, and the `var` - * \f$ \sigma^2 \f$. Lambda is called the variance-scaling factor. Note that - * this hierarchy is conjugate, thus the marginal distribution is available in - * closed form. For more information, please refer to the parent class - * `BaseHierarchy`, to the class `UniLinRegLikelihood` for details on the + * \f$ \sigma^2 \f$. \f$ \Lambda \f$ is called the variance-scaling factor. + * Note that this hierarchy is conjugate, thus the marginal distribution is + * available in closed form. For more information, please refer to the parent + * class `BaseHierarchy`, to the class `UniLinRegLikelihood` for details on the * likelihood model and to `MNIGPriorModel` for details on the prior model. */ diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index 36d14c94e..4d1db6200 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -45,9 +45,10 @@ class NNIGHierarchy }; //! Evaluates the log-marginal distribution of data in a single point - //! @param params Container of (prior or posterior) hyperparameter values - //! @param datum Point which is to be evaluated - //! @return The evaluation of the lpdf + //! @param hier_params Container of (prior or posterior) hyperparameter + //! values + //! @param datum Point which is to be evaluated + //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { auto params = hier_params->nnig_state(); From fbde5e99c0e84d0e13463d1436a1a3f66ec114b9 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:36:36 +0200 Subject: [PATCH 305/317] Improved rst files (ONGOING) --- docs/Doxyfile.in | 4 +-- docs/hierarchies.rst | 79 ++++++++++++++++++++++++---------------- docs/index.rst | 11 ++++-- docs/likelihoods.rst | 85 ++++++++++++++++++++++++++++++++++++++++++++ docs/utils.rst | 13 ++++--- 5 files changed, 152 insertions(+), 40 deletions(-) create mode 100644 docs/likelihoods.rst diff --git a/docs/Doxyfile.in b/docs/Doxyfile.in index 30cda29e9..6b163fa99 100644 --- a/docs/Doxyfile.in +++ b/docs/Doxyfile.in @@ -1819,7 +1819,7 @@ PAPER_TYPE = a4 # If left blank no extra packages will be included. # This tag requires that the tag GENERATE_LATEX is set to YES. -EXTRA_PACKAGES = +EXTRA_PACKAGES = bm # The LATEX_HEADER tag can be used to specify a personal LaTeX header for the # generated LaTeX document. The header should contain everything until the first @@ -1893,7 +1893,7 @@ USE_PDFLATEX = YES # The default value is: NO. # This tag requires that the tag GENERATE_LATEX is set to YES. -LATEX_BATCHMODE = NO +LATEX_BATCHMODE = YES # If the LATEX_HIDE_INDICES tag is set to YES then doxygen will not include the # index chapters (such as File Index, Compound Index, etc.) in the output. diff --git a/docs/hierarchies.rst b/docs/hierarchies.rst index 21c6fe1d7..7f520588a 100644 --- a/docs/hierarchies.rst +++ b/docs/hierarchies.rst @@ -6,24 +6,23 @@ Hierarchies In our algorithms, we store a vector of hierarchies, each of which represent a parameter :math:`\theta_h`. The hierarchy implements all the methods needed to update :math:`\theta_h`: sampling from the prior distribution :math:`P_0`, the full-conditional distribution (given the data {:math:`y_i` such that :math:`c_i = h`} ) and so on. - ------------------------- Main operations performed ------------------------- A hierarchy must be able to perform the following operations: -1. Sample from the prior distribution: generate :math:`\theta_h \sim P_0` [``sample_prior``] -2. Sample from the 'full conditional' distribution: generate theta_h from the distribution :math:`p(\theta_h \mid \cdots ) \propto P_0(\theta_h) \prod_{i: c_i = h} k(y_i | \theta_h)` [``sample_full_conditional``] -3. Update the hyperparameters involved in :math:`P_0` [``update_hypers``] -4. Evaluate the likelihood in one point, i.e. :math:`k(x | \theta_h)` for theta_h the current value of the parameters [``like_lpdf``] -5. When :math:`k` and :math:`P_0` are conjugate, we must also be able to compute the marginal/prior predictive distribution in one point, i.e. :math:`m(x) = \int k(x | \theta) P_0(d\theta)`, and the conditional predictive distribution :math:`m(x | \textbf{y} ) = \int k(x | \theta) P_0(d\theta | \{y_i: c_i = h\})` [``prior_pred_lpdf``, ``conditional_pred_lpdf``] +a. Sample from the prior distribution: generate :math:`\theta_h \sim P_0` [``sample_prior``] +b. Sample from the 'full conditional' distribution: generate theta_h from the distribution :math:`p(\theta_h \mid \cdots ) \propto P_0(\theta_h) \prod_{i: c_i = h} k(y_i | \theta_h)` [``sample_full_conditional``] +c. Update the hyperparameters involved in :math:`P_0` [``update_hypers``] +d. Evaluate the likelihood in one point, i.e. :math:`k(x | \theta_h)` for theta_h the current value of the parameters [``like_lpdf``] +e. When :math:`k` and :math:`P_0` are conjugate, we must also be able to compute the marginal/prior predictive distribution in one point, i.e. :math:`m(x) = \int k(x | \theta) P_0(d\theta)`, and the conditional predictive distribution :math:`m(x | \textbf{y} ) = \int k(x | \theta) P_0(d\theta | \{y_i: c_i = h\})` [``prior_pred_lpdf``, ``conditional_pred_lpdf``] Moreover, the following utilities are needed: -6. write the current state :math:`\theta_h` into a appropriately defined Protobuf message [``write_state_to_proto``] -7. restore theta_h from a given Protobuf message [``set_state_from_proto``] -8. write the values of the hyperparameters in :math:`P_0` to a Protobuf message [``write_hypers_to_proto``] +f. write the current state :math:`\theta_h` into a appropriately defined Protobuf message [``write_state_to_proto``] +g. restore theta_h from a given Protobuf message [``set_state_from_proto``] +h. write the values of the hyperparameters in :math:`P_0` to a Protobuf message [``write_hypers_to_proto``] In each hierarchy, we also keep track of which data points are allocated to the hierarchy. @@ -42,21 +41,26 @@ The code thus composes of: a virtual class defining the API, a template base cla The class ``AbstractHierarchy`` defines the API, i.e. all the methods that need to be called from outside of a ``Hierarchy`` class. A template class ``BaseHierarchy`` inherits from ``AbstractHierarchy`` and implements some of the necessary virtual methods, which need not be implemented by the child classes. -Instead, child classes must implement: +.. toctree:: + :maxdepth: 1 + :caption: API: hierarchies submodules -1. ``like_lpdf``: evaluates :math:`k(x | \theta_h)` -2. ``marg_lpdf``: evaluates m(x) given some parameters :math:`\theta_h` (could be both the hyperparameters in :math:`P_0` or the paramters given by the full conditionals) -3. ``draw``: samples from :math:`P_0` given the parameters -4. ``clear_summary_statistics``: clears all the summary statistics -5. ``update_hypers``: performs the update of parameters in :math:`P_0` given all the :math:`\theta_h` (passed as a vector of protobuf Messages) -6. ``initialize_state``: initializes the current :math:`\theta_h` given the hyperparameters in :math:`P_0` -7. ``initialize_hypers``: initializes the hyperparameters in :math:`P_0` given their hyperprior -8. ``update_summary_statistics``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy -9. ``get_posterior_parameters``: returns the paramters of the full conditional distribution **possible only when** :math:`P_0` **and** :math:`k` **are conjugate** -10. ``set_state_from_proto`` -11. ``write_state_to_proto`` -12. ``write_hypers_to_proto`` + likelihoods +Instead, child classes must implement: + +a. ``like_lpdf``: evaluates :math:`k(x | \theta_h)` +b. ``marg_lpdf``: evaluates m(x) given some parameters :math:`\theta_h` (could be both the hyperparameters in :math:`P_0` or the paramters given by the full conditionals) +c. ``draw``: samples from :math:`P_0` given the parameters +d. ``clear_summary_statistics``: clears all the summary statistics +e. ``update_hypers``: performs the update of parameters in :math:`P_0` given all the :math:`\theta_h` (passed as a vector of protobuf Messages) +f. ``initialize_state``: initializes the current :math:`\theta_h` given the hyperparameters in :math:`P_0` +g. ``initialize_hypers``: initializes the hyperparameters in :math:`P_0` given their hyperprior +h. ``update_summary_statistics``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy +i. ``get_posterior_parameters``: returns the paramters of the full conditional distribution **possible only when** :math:`P_0` **and** :math:`k` **are conjugate** +j. ``set_state_from_proto`` +k. ``write_state_to_proto`` +l. ``write_hypers_to_proto`` Note that not all of these members are declared virtual in ``AbstractHierarchy`` or ``BaseHierarchy``: this is because virtual members are only the ones that must be called from outside the ``Hierarchy``, the other ones are handled via CRTP. Not having them virtual saves a lot of lookups in the vtables. The ``BaseHierarchy`` class takes 4 template parameters: @@ -66,12 +70,11 @@ The ``BaseHierarchy`` class takes 4 template parameters: 3. ``Hyperparams`` is usually a struct representing the parameters in :math:`P_0` 4. ``Prior`` must be a protobuf object encoding the prior parameters. +.. Finally, a ``ConjugateHierarchy`` takes care of the implementation of some methods that are specific to conjugate models. -Finally, a ``ConjugateHierarchy`` takes care of the implementation of some methods that are specific to conjugate models. - -------- -Classes -------- +---------------- +Abstract Classes +---------------- .. doxygenclass:: AbstractHierarchy :project: bayesmix @@ -79,9 +82,11 @@ Classes .. doxygenclass:: BaseHierarchy :project: bayesmix :members: -.. doxygenclass:: ConjugateHierarchy - :project: bayesmix - :members: + +--------------------------------- +Classes for Conjugate Hierarchies +--------------------------------- + .. doxygenclass:: NNIGHierarchy :project: bayesmix :members: @@ -91,3 +96,17 @@ Classes .. doxygenclass:: LinRegUniHierarchy :project: bayesmix :members: + +------------------------------------- +Classes for Non-Conjugate Hierarchies +------------------------------------- + +.. doxygenclass:: NNxIGHierarchy + :project: bayesmix + :members: +.. doxygenclass:: LapNIGHierarchy + :project: bayesmix + :members: +.. doxygenclass:: FAHierarchy + :project: bayesmix + :members: diff --git a/docs/index.rst b/docs/index.rst index 59d19db44..271387414 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,7 +32,8 @@ There are currently three submodules to the ``bayesmix`` library, represented by Further, we employ Protocol buffers for several purposes, including serialization. The list of all protos with their docs is available in the ``protos`` link below. .. toctree:: - :maxdepth: 1 + :maxdepth: 2 + :titlesonly: :caption: API: library submodules algorithms @@ -43,11 +44,15 @@ Further, we employ Protocol buffers for several purposes, including serializatio utils - Tutorials ========= -:doc:`tutorial` +.. toctree:: + :maxdepth: 1 + + tutorial + +.. :doc:`tutorial` Python interface diff --git a/docs/likelihoods.rst b/docs/likelihoods.rst new file mode 100644 index 000000000..78ce5294d --- /dev/null +++ b/docs/likelihoods.rst @@ -0,0 +1,85 @@ +bayesmix/hierarchies/likelihoods + +Likelihoods +=========== + +The ``Likelihood`` sub-module represents the likelihood we have assumed for the data in a given cluster. Each ``Likelihood`` class represents the sampling model + +.. math:: + y_1, \ldots, y_k \mid \bm{\tau} \stackrel{\small\mathrm{iid}}{\sim} f(\cdot \mid \bm{\tau}) + +for a specific choice of the probability density function :math:`f`. + +------------------------- +Main operations performed +------------------------- + +A likelihood object must be able to perform the following operations: + +a. First of all, we require the \code{lpdf()} and \code{lpdf\_grid()} methods, which simply evaluate the loglikelihood in a given point or in a grid of points (also in case of a \emph{dependent} likelihood, i.e., with covariates associated to each observation) [``lpdf()`` and ``lpdf_grid``] +b. In case you want to rely on a Metropolis-like updater, the likelihood needs to evaluation of the likelihood of the whole cluster starting from the vector of unconstrained parameters [``cluster_lpdf_from_unconstrained()``]. Observe that the ``AbstractLikelihood`` class provides two such methods, one returning a ``double`` and one returning a ``stan::math::var``. The latter is used to automatically compute the gradient of the likelihood via Stan's automatic differentiation, if needed. In practice, users do not need to implement both methods separately and can implement only one templated method +c. manage the insertion and deletion of a datum in the cluster [``add_datum`` and ``remove_datum``] +d. update the summary statistics associated to the likelihood [``update_summary_statistics``]. Summary statistics (when available) are used to evaluate the likelihood function on the whole cluster, as well as to perform the posterior updates of :math:`\bm{\tau}`. This usually gives a substantial speedup + +-------------- +Code structure +-------------- + +In principle, the ``Likelihood`` classes are responsible only of evaluating the log-likelihood function given a specific choice of parameters :math:`\bm{\tau}`. +Therefore, a simple inheritance structure would seem appropriate. However, the nature of the parameters :math:`\bm{\tau}` can be very different across different models (think for instance of the difference between the univariate normal and the multivariate normal paramters). As such, we employ CRTP to manage the polymorphic nature of ``Likelihood`` classes. + +The class ``AbstractLikelihood`` defines the API, i.e. all the methods that need to be called from outside of a ``Likelihood`` class. +A template class ``BaseLikelihood`` inherits from ``AbstractLikelihood`` and implements some of the necessary virtual methods, which need not be implemented by the child classes. + +Instead, child classes **must** implement: + +a. ``compute_lpdf``: evaluates :math:`k(x \mid \theta_h)` +b. ``update_sum_stats``: updates the summary statistics when an observation is allocated or de-allocated from the hierarchy +c. ``clear_summary_statistics``: clears all the summary statistics +d. ``is_dependent``: defines if the given likelihood depends on covariates +e. ``is_multivariate``: defines if the given likelihood is for multivariate data + +In case the likelihood needs to be used in a Metropolis-like updater, child classes **should** also implement: + +f. ``cluster_lpdf_from_unconstrained``: evaluates :math:`\prod_{i: c_i = h} k(x_i \mid \tilde{\theta}_h)`, where :math:`\tilde{\theta}_h` is the vector of unconstrained parameters. + +---------------- +Abstract Classes +---------------- + +.. doxygenclass:: AbstractLikelihood + :project: bayesmix + :members: +.. doxygenclass:: BaseLikelihood + :project: bayesmix + :members: + +---------------------------------- +Classes for Univariate Likelihoods +---------------------------------- + +.. doxygenclass:: UniNormLikelihood + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: UniLinRegLikelihood + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: LaplaceLikelihood + :project: bayesmix + :members: + :protected-members: + +------------------------------------ +Classes for Multivariate Likelihoods +------------------------------------ + +.. doxygenclass:: MultiNormLikelihood + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: FALikelihood + :project: bayesmix + :members: + :protected-members: diff --git a/docs/utils.rst b/docs/utils.rst index 89c7e4533..45420f464 100644 --- a/docs/utils.rst +++ b/docs/utils.rst @@ -18,17 +18,20 @@ Distribution-related utilities :project: bayesmix ---------------------------------------------- -``Eigen`` input-output and matrix manipulation +``Eigen`` matrix manipulation utilities ---------------------------------------------- .. doxygenfile:: eigen_utils.h :project: bayesmix - :members: + +-------------------------------- +``Eigen`` input-output utilities +-------------------------------- .. doxygenfile:: io_utils.h :project: bayesmix -------------------------- -``protobuf`` input-output -------------------------- +----------------------------------- +``protobuf`` input-output utilities +----------------------------------- .. doxygenfile:: proto_utils.h :project: bayesmix From 9b13c4fbbb3fb2bdc73d7c7a6a9ef08b68fdc54e Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:37:10 +0200 Subject: [PATCH 306/317] Switch-off docker (TO RESTORE) --- docs/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/CMakeLists.txt b/docs/CMakeLists.txt index 5981d2b7a..3318965a8 100644 --- a/docs/CMakeLists.txt +++ b/docs/CMakeLists.txt @@ -82,6 +82,6 @@ install(DIRECTORY ${SPHINX_BUILD} DESTINATION ${CMAKE_INSTALL_DOCDIR}) add_custom_target(document_bayesmix) -add_dependencies(document_bayesmix document_protos) +# add_dependencies(document_bayesmix document_protos) add_dependencies(document_bayesmix Doxygen) add_dependencies(document_bayesmix Sphinx) From cb9d69c6b0a7be17e4802bfcee48cdb1915513c1 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Mon, 16 May 2022 15:37:22 +0200 Subject: [PATCH 307/317] Comment line --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 3501007dd..af16e0bdc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -57,7 +57,7 @@ def configureDoxyfile(input_dir, output_dir): html_theme = 'haiku' -html_static_path = ['_static'] +# html_static_path = ['_static'] highlight_language = 'cpp' From 6e264ac25d62259bbf9b3cc2525486084cb7984e Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 16 May 2022 16:53:58 +0200 Subject: [PATCH 308/317] addedd missing classes --- docs/algorithms.rst | 3 + docs/hierarchies.rst | 22 +- docs/likelihoods.rst | 2 +- docs/mixings.rst | 3 + docs/protos.html | 6095 ++++++++++++++++++++++-------------------- 5 files changed, 3230 insertions(+), 2895 deletions(-) diff --git a/docs/algorithms.rst b/docs/algorithms.rst index c61efc257..635800167 100644 --- a/docs/algorithms.rst +++ b/docs/algorithms.rst @@ -20,6 +20,9 @@ Algorithms .. doxygenclass:: Neal8Algorithm :project: bayesmix :members: +.. doxygenclass:: SplitMergeAlgorithm + :project: bayesmix + :members: .. doxygenclass:: ConditionalAlgorithm :project: bayesmix :members: diff --git a/docs/hierarchies.rst b/docs/hierarchies.rst index 7f520588a..f5272c1fd 100644 --- a/docs/hierarchies.rst +++ b/docs/hierarchies.rst @@ -6,6 +6,21 @@ Hierarchies In our algorithms, we store a vector of hierarchies, each of which represent a parameter :math:`\theta_h`. The hierarchy implements all the methods needed to update :math:`\theta_h`: sampling from the prior distribution :math:`P_0`, the full-conditional distribution (given the data {:math:`y_i` such that :math:`c_i = h`} ) and so on. +In BayesMix, each choice of :math:`G_0` is implemented in a different ``PriorModel`` object and each choice of :math:k(\cdot \mid \cdot)` in a ``Likelihood`` object, so that it is straightforward to create a new ``Hierarchy`` using one of the already implemented priors or likelihoods. +The sampling from the full conditional of :math:`\theta_h` is performed in an ``Updater`` class. +`State` classes are used to store parameters ``\theta_h`s of every mixture component. +Their main purpose is to handle serialization and de-serialization of the state + +.. toctree:: + :maxdepth: 1 + :caption: API: hierarchies submodules + + likelihoods + prior_models + updaters + states + + ------------------------- Main operations performed ------------------------- @@ -41,12 +56,6 @@ The code thus composes of: a virtual class defining the API, a template base cla The class ``AbstractHierarchy`` defines the API, i.e. all the methods that need to be called from outside of a ``Hierarchy`` class. A template class ``BaseHierarchy`` inherits from ``AbstractHierarchy`` and implements some of the necessary virtual methods, which need not be implemented by the child classes. -.. toctree:: - :maxdepth: 1 - :caption: API: hierarchies submodules - - likelihoods - Instead, child classes must implement: a. ``like_lpdf``: evaluates :math:`k(x | \theta_h)` @@ -72,6 +81,7 @@ The ``BaseHierarchy`` class takes 4 template parameters: .. Finally, a ``ConjugateHierarchy`` takes care of the implementation of some methods that are specific to conjugate models. + ---------------- Abstract Classes ---------------- diff --git a/docs/likelihoods.rst b/docs/likelihoods.rst index 78ce5294d..72f2d73cd 100644 --- a/docs/likelihoods.rst +++ b/docs/likelihoods.rst @@ -16,7 +16,7 @@ Main operations performed A likelihood object must be able to perform the following operations: -a. First of all, we require the \code{lpdf()} and \code{lpdf\_grid()} methods, which simply evaluate the loglikelihood in a given point or in a grid of points (also in case of a \emph{dependent} likelihood, i.e., with covariates associated to each observation) [``lpdf()`` and ``lpdf_grid``] +a. First of all, we require the ``lpdf()`` and ``lpdf\_grid()`` methods, which simply evaluate the loglikelihood in a given point or in a grid of points (also in case of a \emph{dependent} likelihood, i.e., with covariates associated to each observation) [``lpdf()`` and ``lpdf_grid``] b. In case you want to rely on a Metropolis-like updater, the likelihood needs to evaluation of the likelihood of the whole cluster starting from the vector of unconstrained parameters [``cluster_lpdf_from_unconstrained()``]. Observe that the ``AbstractLikelihood`` class provides two such methods, one returning a ``double`` and one returning a ``stan::math::var``. The latter is used to automatically compute the gradient of the likelihood via Stan's automatic differentiation, if needed. In practice, users do not need to implement both methods separately and can implement only one templated method c. manage the insertion and deletion of a datum in the cluster [``add_datum`` and ``remove_datum``] d. update the summary statistics associated to the likelihood [``update_summary_statistics``]. Summary statistics (when available) are used to evaluate the likelihood function on the whole cluster, as well as to perform the posterior updates of :math:`\bm{\tau}`. This usually gives a substantial speedup diff --git a/docs/mixings.rst b/docs/mixings.rst index 67d2b1074..64f3b593f 100644 --- a/docs/mixings.rst +++ b/docs/mixings.rst @@ -34,6 +34,9 @@ Classes .. doxygenclass:: PitYorMixing :project: bayesmix :members: +.. doxygenclass:: MixtureFiniteMixing + :project: bayesmix + :members: .. doxygenclass:: TruncatedSBMixing :project: bayesmix :members: diff --git a/docs/protos.html b/docs/protos.html index abc4c641b..26b9929a3 100644 --- a/docs/protos.html +++ b/docs/protos.html @@ -3,8 +3,12 @@ Protocol Documentation - - + + - - + -

Protocol Documentation

Table of Contents

- - -
-

algorithm_id.proto

Top -
-

- - - - -

AlgorithmId

-

Enum for the different types of algorithms.

References

[1] R. M. Neal, Markov Chain Sampling Methods for Dirichlet Process Mixture Models. JCGS(2000)

[2] H. Ishwaran and L. F. James, Gibbs Sampling Methods for Stick-Breaking Priors. JASA(2001)

[3] S. Jain and R. M. Neal, A Split-Merge Markov Chain Monte Carlo Procedure for the Dirichlet Process Mixture Model. JCGS (2004)

[4] M. Kalli, J. Griffin and S. G. Walker, Slice sampling mixture models. Stat and Comp. (2011)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
NameNumberDescription
UNKNOWN_ALGORITHM0

Neal21

Neal's Algorithm 2, see [1]

Neal32

Neal's Algorithm 3, see [1]

Neal83

Neal's Algorithm 8, see [1]

BlockedGibbs4

Ishwaran and James Blocked Gibbs, see [2]

SplitMerge5

Jain and Neal's Split&Merge, see [3]. NOT IMPLEMENTED YET!

Slice6

Slice sampling, see [4]. NOT IMPLEMENTED YET!

- - - - - - - -
-

algorithm_params.proto

Top -
-

- - -

AlgorithmParams

-

Parameters used in the BaseAlgorithm class and childs.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
algo_idstring

Id of the Algorithm. Must match the ones in the AlgorithmId enum

rng_seeduint32

Seed for the random number generator

iterationsuint32

Total number of iterations of the MCMC chain

burninuint32

Number of iterations to discard as burn-in

init_num_clustersuint32

Number of clusters to initialize the algorithm. It may be overridden by conditional mixings for which the number of components is fixed (e.g. TruncatedSBMixing). In this case, this value is ignored.

neal8_n_auxuint32

Number of auxiliary unique values for the Neal8 algorithm

splitmerge_n_restr_gs_updatesuint32

Number of restricted GS scans for each MH step.

splitmerge_n_mh_updatesuint32

Number of MH updates for each iteration of Split and Merge algorithm.

splitmerge_n_full_gs_updatesuint32

Number of full GS scans for each iteration of Split and Merge algorithm.

- - - - - - - - - - - - - -
-

algorithm_state.proto

Top -
-

- - -

AlgorithmState

-

This message represents the state of a Gibbs sampler for

a mixture model. All algorithms must be able to handle this

message, by filling it with the current state of the sampler

in the `get_state_as_proto` method.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
cluster_statesAlgorithmState.ClusterStaterepeated

The state of each cluster

cluster_allocsint32repeated

Vector of allocations into clusters, one for each observation

mixing_stateMixingState

The state of the `Mixing`

iteration_numint32

The iteration number

hierarchy_hypersAlgorithmState.HierarchyHypers

The current values of the hyperparameters of the hierarchy

- - - - - -

AlgorithmState.ClusterState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
uni_ls_stateUniLSState

State of a univariate location-scale family

multi_ls_stateMultiLSState

State of a multivariate location-scale family

lin_reg_uni_ls_stateLinRegUniLSState

State of a linear regression univariate location-scale family

general_stateVector

Just a vector of doubles

fa_stateFAState

State of a Mixture of Factor Analysers

cardinalityint32

How many observations are in this cluster

- - - - - -

AlgorithmState.HierarchyHypers

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fake_priorEmptyPrior

nnig_stateNIGDistribution

nnw_stateNWDistribution

lin_reg_uni_stateMultiNormalIGDistribution

lapnig_stateLapNIGState

fa_stateFAPriorDistribution

- - - - - - - - - - - - - -
-

distribution.proto

Top -
-

- - -

BetaDistribution

-

Parameters defining a beta distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
shape_adouble

shape_bdouble

- - - - - -

GammaDistribution

-

Parameters defining a gamma distribution with density

f(x) = x^(shape-1) * exp(-rate * x) / Gamma(shape)

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
shapedouble

ratedouble

- - - - - -

InvWishartDistribution

-

Parameters defining an Inverse Wishart distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
deg_freedouble

scaleMatrix

- - - - - -

MultiNormalDistribution

-

Parameters defining a multivariate normal distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

varMatrix

- - - - - -

MultiNormalIGDistribution

-

Parameters for the Normal Inverse Gamma distribution commonly employed in

linear regression models, with density

f(beta, var) = N(beta | mean, var * var_scaling^{-1}) * IG(var | shape, scale)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

var_scalingMatrix

shapedouble

scaledouble

- - - - - -

NIGDistribution

-

Parameters of a Normal Inverse-Gamma distribution

with density

f(x, y) = N(x | mu, y/var_scaling) * IG(y | shape, scale)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

var_scalingdouble

shapedouble

scaledouble

- - - - - -

NWDistribution

-

Parameters of a Normal Wishart distribution

with density

f(x, y) = N(x | mu, (y * var_scaling)^{-1}) * IW(y | deg_free, scale)

where x is a vector and y is a matrix (spd)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

var_scalingdouble

deg_freedouble

scaleMatrix

- - - - - -

UniNormalDistribution

-

Parameters defining a univariate normal distribution

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

vardouble

- - - - - - - - - - - - - -
-

hierarchy_id.proto

Top -
-

- - - - -

HierarchyId

-

Enum for the different types of Hierarchy.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
NameNumberDescription
UNKNOWN_HIERARCHY0

NNIG1

Normal - Normal Inverse Gamma

NNW2

Normal - Normal Wishart

LinRegUni3

Linear Regression (univariate response)

LapNIG4

Laplace - Normal Inverse Gamma

FA5

Factor Analysers

- - - - - - - -
-

hierarchy_prior.proto

Top -
-

- - -

EmptyPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fake_fielddouble

- - - - - -

FAPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesFAPriorDistribution

- - - - - -

FAPriorDistribution

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mutildeVector

betaVector

phidouble

alpha0double

quint32

- - - - - -

LapNIGPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesLapNIGState

- - - - - -

LapNIGState

-

Prior for the parameters of the base measure in a Laplace - Normal Inverse Gamma hierarchy

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

vardouble

shapedouble

scaledouble

mh_mean_vardouble

mh_log_scale_vardouble

- - - - - -

LinRegUniPrior

-

Prior for the parameters of the base measure in a Normal mixture model with a covariate-dependent

location.

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesMultiNormalIGDistribution

- - - - - -

NNIGPrior

-

Prior for the parameters of the base measure in a Normal-Normal Inverse Gamma hierarchy

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesNIGDistribution

no prior, just fixed values

normal_mean_priorNNIGPrior.NormalMeanPrior

prior on the mean

ngg_priorNNIGPrior.NGGPrior

prior on the mean, var_scaling, and scale

- - - - - -

NNIGPrior.NGGPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorUniNormalDistribution

var_scaling_priorGammaDistribution

shapedouble

scale_priorGammaDistribution

- - - - - -

NNIGPrior.NormalMeanPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorUniNormalDistribution

var_scalingdouble

shapedouble

scaledouble

- - - - - -

NNWPrior

-

Prior for the parameters of the base measure in a Normal-Normal Wishart hierarchy

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesNWDistribution

no prior, just fixed values

normal_mean_priorNNWPrior.NormalMeanPrior

prior on the mean

ngiw_priorNNWPrior.NGIWPrior

prior on the mean, var_scaling, and scale

- - - - - -

NNWPrior.NGIWPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorMultiNormalDistribution

var_scaling_priorGammaDistribution

deg_freedouble

scale_priorInvWishartDistribution

- - - - - -

NNWPrior.NormalMeanPrior

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
mean_priorMultiNormalDistribution

var_scalingdouble

deg_freedouble

scaleMatrix

- - - - - - - - - - - - - -
-

ls_state.proto

Top -
-

- - -

FAState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
muVector

psiVector

etaMatrix

lambdaMatrix

- - - - - -

LinRegUniLSState

-

Parameters of a univariate linear regression

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
regression_coeffsVector

regression coefficients

vardouble

variance of the noise

- - - - - -

MultiLSState

-

Parameters of a multivariate location-scale family of distributions,

parameterized by mean and precision (inverse of variance). For

convenience, we also store the Cholesky factor of the precision matrix.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meanVector

precMatrix

prec_cholMatrix

- - - - - -

UniLSState

-

Parameters of a univariate location-scale family of distributions.

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
meandouble

vardouble

- - - - - - - - - - - - - -
-

matrix.proto

Top -
-

- - -

Matrix

-

Message representing a matrix of doubles.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
rowsint32

number of rows

colsint32

number of columns

datadoublerepeated

matrix elements

rowmajorbool

if true, the data is read in row-major order

- - - - - -

Vector

-

Message representing a vector of doubles.

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
sizeint32

number of elements in the vector

datadoublerepeated

vector elements

- - - - - - - - - - - - - -
-

mixing_id.proto

Top -
-

- - - - -

MixingId

-

Enum for the different types of Mixing.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
NameNumberDescription
UNKNOWN_MIXING0

DP1

Dirichlet Process

PY2

Pitman-Yor Process

LogSB3

Logit Stick-Breaking Process

TruncSB4

Truncated Stick-Breaking Process

MFM5

Mixture of finite mixtures

- - - - - - - -
-

mixing_prior.proto

Top -
-

- - -

DPPrior

-

Prior for the concentration parameter of a Dirichlet process

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valueDPState

No prior, just a fixed value

gamma_priorDPPrior.GammaPrior

Gamma prior on the total mass

- - - - - -

DPPrior.GammaPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmass_priorGammaDistribution

- - - - - -

LogSBPrior

-

Definition of the parameters of a Logit-Stick Breaking process.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
normal_priorMultiNormalDistribution

Normal prior on the regression coefficients

step_sizedouble

Steps size for the MALA algorithm used for posterior inference (TODO: move?)

num_componentsuint32

Number of components in the process

- - - - - -

MFMPrior

-

Prior for the Poisson rate and Dirichlet parameters of a MFM (Finite Dirichlet) process.

For the moment, we only support fixed values

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valueMFMState

No prior, just a fixed value

- - - - - -

PYPrior

-

Prior for the strength and discount parameters of a Pitman-Yor process.

For the moment, we only support fixed values

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
fixed_valuesPYState

- - - - - -

TruncSBPrior

-

Definition of the parameters of a truncated Stick-Breaking process

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
beta_priorsTruncSBPrior.BetaPriors

General stick-breaking distributions

dp_priorTruncSBPrior.DPPrior

Truncated Dirichlet process

py_priorTruncSBPrior.PYPrior

Truncated Pitman-Yor process

mfm_priorTruncSBPrior.MFMPrior

num_componentsuint32

Number of components in the process

- - - - - -

TruncSBPrior.BetaPriors

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
beta_distributionsBetaDistributionrepeated

General stick-breaking distributions

- - - - - -

TruncSBPrior.DPPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

- - - - - -

TruncSBPrior.MFMPrior

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

- - - - - -

TruncSBPrior.PYPrior

-

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
strengthdouble

Truncated Pitman-Yor process

discountdouble

- - - - - - - - - - - - - -
-

mixing_state.proto

Top -
-

- - -

DPState

-

State of a Dirichlet process

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
totalmassdouble

the total mass of the DP

- - - - - -

LogSBState

-

State of a Logit-Stick Breaking process

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
regression_coeffsMatrix

Num_Components x Num_Features matrix. Each row is the regression coefficients for a component.

- - - - - -

MFMState

-

State of a MFM (Finite Dirichlet) process

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
lambdadouble

rate parameter of Poisson prior on number of compunents of the MFM

gammadouble

parameter of the dirichlet distribution for the mixing weights

- - - - - -

MixingState

-

Wrapper of all possible mixing states into a single oneof

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
dp_stateDPState

py_statePYState

log_sb_stateLogSBState

trunc_sb_stateTruncSBState

mfm_stateMFMState

- - - - - -

PYState

-

State of a Pitman-Yor process

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
strengthdouble

discountdouble

- - - - - -

TruncSBState

-

State of a truncated sitck breaking process. For convenice we store also the logarithm of the weights

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
sticksVector

logweightsVector

- - - - - - - - - - - - - -
-

semihdp.proto

Top -
-

- - -

SemiHdpParams

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
pseudo_priorSemiHdpParams.PseudoPriorParams

dirichlet_concentrationdouble

rest_allocs_updatestring

Either "full", "metro_base", "metro_dist"

totalmass_restdouble

totalmass_hdpdouble

w_priorSemiHdpParams.WPriorParams

- - - - - -

SemiHdpParams.PseudoPriorParams

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
card_weightdouble

mean_perturb_sddouble

var_perturb_fracdouble

- - - - - -

SemiHdpParams.WPriorParams

-

- - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
shape1double

shape2double

- - - - - -

SemiHdpState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
restaurantsSemiHdpState.RestaurantStaterepeated

groupsSemiHdpState.GroupStaterepeated

tausSemiHdpState.ClusterStaterepeated

cint32repeated

wdouble

- - - - - -

SemiHdpState.ClusterState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
uni_ls_stateUniLSState

multi_ls_stateMultiLSState

lin_reg_uni_ls_stateLinRegUniLSState

cardinalityint32

- - - - - -

SemiHdpState.GroupState

-

- - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
cluster_allocsint32repeated

- - - - - -

SemiHdpState.RestaurantState

-

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FieldTypeLabelDescription
theta_starsSemiHdpState.ClusterStaterepeated

n_by_clusint32repeated

table_to_sharedint32repeated

table_to_idioint32repeated

- - - - - - - - - - - - +
+

algorithm_id.proto

+ Top +
+

+ +

AlgorithmId

+

Enum for the different types of algorithms.

+

References

+

+ [1] R. M. Neal, Markov Chain Sampling Methods for Dirichlet Process + Mixture Models. JCGS(2000) +

+

+ [2] H. Ishwaran and L. F. James, Gibbs Sampling Methods for Stick-Breaking + Priors. JASA(2001) +

+

+ [3] S. Jain and R. M. Neal, A Split-Merge Markov Chain Monte Carlo + Procedure for the Dirichlet Process Mixture Model. JCGS (2004) +

+

+ [4] M. Kalli, J. Griffin and S. G. Walker, Slice sampling mixture models. + Stat and Comp. (2011) +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameNumberDescription
UNKNOWN_ALGORITHM0

Neal21

Neal's Algorithm 2, see [1]

Neal32

Neal's Algorithm 3, see [1]

Neal83

Neal's Algorithm 8, see [1]

BlockedGibbs4

Ishwaran and James Blocked Gibbs, see [2]

SplitMerge5 +

+ Jain and Neal's Split&Merge, see [3]. NOT IMPLEMENTED YET! +

+
Slice6

Slice sampling, see [4]. NOT IMPLEMENTED YET!

+ +
+

algorithm_params.proto

+ Top +
+

+ +

AlgorithmParams

+

Parameters used in the BaseAlgorithm class and childs.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
algo_idstring +

+ Id of the Algorithm. Must match the ones in the AlgorithmId enum +

+
rng_seeduint32

Seed for the random number generator

iterationsuint32

Total number of iterations of the MCMC chain

burninuint32

Number of iterations to discard as burn-in

init_num_clustersuint32 +

+ Number of clusters to initialize the algorithm. It may be + overridden by conditional mixings for which the number of + components is fixed (e.g. TruncatedSBMixing). In this case, this + value is ignored. +

+
neal8_n_auxuint32 +

Number of auxiliary unique values for the Neal8 algorithm

+
splitmerge_n_restr_gs_updatesuint32

Number of restricted GS scans for each MH step.

splitmerge_n_mh_updatesuint32 +

+ Number of MH updates for each iteration of Split and Merge + algorithm. +

+
splitmerge_n_full_gs_updatesuint32 +

+ Number of full GS scans for each iteration of Split and Merge + algorithm. +

+
+ +
+

algorithm_state.proto

+ Top +
+

+ +

AlgorithmState

+

This message represents the state of a Gibbs sampler for

+

a mixture model. All algorithms must be able to handle this

+

message, by filling it with the current state of the sampler

+

in the `get_state_as_proto` method.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
cluster_states + AlgorithmState.ClusterState + repeated

The state of each cluster

cluster_allocsint32repeated +

Vector of allocations into clusters, one for each observation

+
mixing_stateMixingState

The state of the `Mixing`

iteration_numint32

The iteration number

hierarchy_hypers + AlgorithmState.HierarchyHypers + +

The current values of the hyperparameters of the hierarchy

+
+ +

+ AlgorithmState.ClusterState +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
uni_ls_stateUniLSState

State of a univariate location-scale family

multi_ls_stateMultiLSState

State of a multivariate location-scale family

lin_reg_uni_ls_stateLinRegUniLSState +

State of a linear regression univariate location-scale family

+
general_stateVector

Just a vector of doubles

fa_stateFAState

State of a Mixture of Factor Analysers

cardinalityint32

How many observations are in this cluster

+ +

+ AlgorithmState.HierarchyHypers +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
general_stateVector

nnig_stateNIGDistribution

nnw_stateNWDistribution

lin_reg_uni_state + MultiNormalIGDistribution +

nnxig_stateNxIGDistribution

fa_state + FAPriorDistribution +

+ +
+

distribution.proto

+ Top +
+

+ +

BetaDistribution

+

Parameters defining a beta distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
shape_adouble

shape_bdouble

+ +

GammaDistribution

+

Parameters defining a gamma distribution with density

+

f(x) = x^(shape-1) * exp(-rate * x) / Gamma(shape)

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
shapedouble

ratedouble

+ +

InvWishartDistribution

+

Parameters defining an Inverse Wishart distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
deg_freedouble

scaleMatrix

+ +

MultiNormalDistribution

+

Parameters defining a multivariate normal distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

varMatrix

+ +

MultiNormalIGDistribution

+

+ Parameters for the Normal Inverse Gamma distribution commonly employed in +

+

linear regression models, with density

+

+ f(beta, var) = N(beta | mean, var * var_scaling^{-1}) * IG(var | shape, + scale) +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

var_scalingMatrix

shapedouble

scaledouble

+ +

NIGDistribution

+

Parameters of a Normal Inverse-Gamma distribution

+

with density

+

f(x, y) = N(x | mu, y/var_scaling) * IG(y | shape, scale)

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

var_scalingdouble

shapedouble

scaledouble

+ +

NWDistribution

+

Parameters of a Normal Wishart distribution

+

with density

+

f(x, y) = N(x | mu, (y * var_scaling)^{-1}) * IW(y | deg_free, scale)

+

where x is a vector and y is a matrix (spd)

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

var_scalingdouble

deg_freedouble

scaleMatrix

+ +

NxIGDistribution

+

Parameters of a Normal x Inverse-Gamma distribution

+

with density

+

f(x, y) = N(x | mu, var) * IG(y | shape, scale)

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

shapedouble

scaledouble

+ +

UniNormalDistribution

+

Parameters defining a univariate normal distribution

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

+ +
+

hierarchy_id.proto

+ Top +
+

+ +

HierarchyId

+

Enum for the different types of Hierarchy.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameNumberDescription
UNKNOWN_HIERARCHY0

NNIG1

Normal - Normal Inverse Gamma

NNW2

Normal - Normal Wishart

LinRegUni3

Linear Regression (univariate response)

LapNIG4

Laplace - Normal Inverse Gamma

FA5

Factor Analysers

NNxIG6

Normal - Normal x Inverse Gamma

+ +
+

hierarchy_prior.proto

+ Top +
+

+ +

EmptyPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fake_fielddouble

+ +

FAPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_values + FAPriorDistribution +

+ +

FAPriorDistribution

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mutildeVector

betaVector

phidouble

alpha0double

quint32

+ +

LapNIGPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesLapNIGState

+ +

LapNIGState

+

+ Prior for the parameters of the base measure in a Laplace - Normal Inverse + Gamma hierarchy +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

shapedouble

scaledouble

mh_mean_vardouble

mh_log_scale_vardouble

+ +

LinRegUniPrior

+

+ Prior for the parameters of the base measure in a Normal mixture model + with a covariate-dependent +

+

location.

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_values + MultiNormalIGDistribution +

+ +

NNIGPrior

+

+ Prior for the parameters of the base measure in a Normal-Normal Inverse + Gamma hierarchy +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesNIGDistribution

no prior, just fixed values

normal_mean_prior + NNIGPrior.NormalMeanPrior +

prior on the mean

ngg_priorNNIGPrior.NGGPrior

prior on the mean, var_scaling, and scale

+ +

NNIGPrior.NGGPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + UniNormalDistribution +

var_scaling_priorGammaDistribution

shapedouble

scale_priorGammaDistribution

+ +

NNIGPrior.NormalMeanPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + UniNormalDistribution +

var_scalingdouble

shapedouble

scaledouble

+ +

NNWPrior

+

+ Prior for the parameters of the base measure in a Normal-Normal Wishart + hierarchy +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesNWDistribution

no prior, just fixed values

normal_mean_prior + NNWPrior.NormalMeanPrior +

prior on the mean

ngiw_priorNNWPrior.NGIWPrior

prior on the mean, var_scaling, and scale

+ +

NNWPrior.NGIWPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + MultiNormalDistribution +

var_scaling_priorGammaDistribution

deg_freedouble

scale_prior + InvWishartDistribution +

+ +

NNWPrior.NormalMeanPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mean_prior + MultiNormalDistribution +

var_scalingdouble

deg_freedouble

scaleMatrix

+ +

NNxIGPrior

+

+ Prior for the parameters of the base measure in a Normal-Normal x Inverse + Gamma hierarchy +

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesNxIGDistribution

no prior, just fixed values

+ +
+

ls_state.proto

+ Top +
+

+ +

FAState

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
muVector

psiVector

etaMatrix

lambdaMatrix

+ +

LinRegUniLSState

+

Parameters of a univariate linear regression

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
regression_coeffsVector

regression coefficients

vardouble

variance of the noise

+ +

MultiLSState

+

Parameters of a multivariate location-scale family of distributions,

+

parameterized by mean and precision (inverse of variance). For

+

+ convenience, we also store the Cholesky factor of the precision matrix. +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meanVector

precMatrix

prec_cholMatrix

+ +

UniLSState

+

Parameters of a univariate location-scale family of distributions.

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
meandouble

vardouble

+ +
+

matrix.proto

+ Top +
+

+ +

Matrix

+

Message representing a matrix of doubles.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
rowsint32

number of rows

colsint32

number of columns

datadoublerepeated

matrix elements

rowmajorbool

if true, the data is read in row-major order

+ +

Vector

+

Message representing a vector of doubles.

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
sizeint32

number of elements in the vector

datadoublerepeated

vector elements

+ +
+

mixing_id.proto

+ Top +
+

+ +

MixingId

+

Enum for the different types of Mixing.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameNumberDescription
UNKNOWN_MIXING0

DP1

Dirichlet Process

PY2

Pitman-Yor Process

LogSB3

Logit Stick-Breaking Process

TruncSB4

Truncated Stick-Breaking Process

MFM5

Mixture of finite mixtures

+ +
+

mixing_prior.proto

+ Top +
+

+ +

DPPrior

+

Prior for the concentration parameter of a Dirichlet process

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valueDPState

No prior, just a fixed value

gamma_priorDPPrior.GammaPrior

Gamma prior on the total mass

+ +

DPPrior.GammaPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmass_priorGammaDistribution

+ +

LogSBPrior

+

Definition of the parameters of a Logit-Stick Breaking process.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
normal_prior + MultiNormalDistribution +

Normal prior on the regression coefficients

step_sizedouble +

+ Steps size for the MALA algorithm used for posterior inference + (TODO: move?) +

+
num_componentsuint32

Number of components in the process

+ +

MFMPrior

+

+ Prior for the Poisson rate and Dirichlet parameters of a MFM (Finite + Dirichlet) process. +

+

For the moment, we only support fixed values

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valueMFMState

No prior, just a fixed value

+ +

PYPrior

+

+ Prior for the strength and discount parameters of a Pitman-Yor process. +

+

For the moment, we only support fixed values

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
fixed_valuesPYState

+ +

TruncSBPrior

+

Definition of the parameters of a truncated Stick-Breaking process

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
beta_priors + TruncSBPrior.BetaPriors +

General stick-breaking distributions

dp_prior + TruncSBPrior.DPPrior +

Truncated Dirichlet process

py_prior + TruncSBPrior.PYPrior +

Truncated Pitman-Yor process

mfm_prior + TruncSBPrior.MFMPrior +

num_componentsuint32

Number of components in the process

+ +

TruncSBPrior.BetaPriors

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
beta_distributionsBetaDistributionrepeated

General stick-breaking distributions

+ +

TruncSBPrior.DPPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

+ +

TruncSBPrior.MFMPrior

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmassdouble

Truncated Dirichlet process

+ +

TruncSBPrior.PYPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
strengthdouble

Truncated Pitman-Yor process

discountdouble

+ +
+

mixing_state.proto

+ Top +
+

+ +

DPState

+

State of a Dirichlet process

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
totalmassdouble

the total mass of the DP

+ +

LogSBState

+

State of a Logit-Stick Breaking process

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
regression_coeffsMatrix +

+ Num_Components x Num_Features matrix. Each row is the regression + coefficients for a component. +

+
+ +

MFMState

+

State of a MFM (Finite Dirichlet) process

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
lambdadouble +

+ rate parameter of Poisson prior on number of compunents of the MFM +

+
gammadouble +

+ parameter of the dirichlet distribution for the mixing weights +

+
+ +

MixingState

+

Wrapper of all possible mixing states into a single oneof

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
dp_stateDPState

py_statePYState

log_sb_stateLogSBState

trunc_sb_stateTruncSBState

mfm_stateMFMState

+ +

PYState

+

State of a Pitman-Yor process

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
strengthdouble

discountdouble

+ +

TruncSBState

+

+ State of a truncated sitck breaking process. For convenice we store also + the logarithm of the weights +

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
sticksVector

logweightsVector

+ +
+

mixture_model.proto

+ Top +
+

+ +

HierarchyPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
nnig_priorNNIGPrior

lapnig_priorLapNIGPrior

nnw_priorNNWPrior

lin_reg_priorLinRegUniPrior

fa_priorFAPrior

+ +

MixingPrior

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
dp_priorDPPrior

py_priorPYPrior

log_sb_priorLogSBPrior

trunc_sb_priorTruncSBPrior

+ +

MixtureModel

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
mixingMixingId

hierarchyHierarchyId

mixing_priorMixingPrior

hierarchy_priorHierarchyPrior

+ +
+

semihdp.proto

+ Top +
+

+ +

SemiHdpParams

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
pseudo_prior + SemiHdpParams.PseudoPriorParams +

dirichlet_concentrationdouble

rest_allocs_updatestring +

+ Either "full", "metro_base", "metro_dist" +

+
totalmass_restdouble

totalmass_hdpdouble

w_prior + SemiHdpParams.WPriorParams +

+ +

+ SemiHdpParams.PseudoPriorParams +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
card_weightdouble

mean_perturb_sddouble

var_perturb_fracdouble

+ +

SemiHdpParams.WPriorParams

+

+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
shape1double

shape2double

+ +

SemiHdpState

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
restaurants + SemiHdpState.RestaurantState + repeated

groups + SemiHdpState.GroupState + repeated

taus + SemiHdpState.ClusterState + repeated

cint32repeated

wdouble

+ +

SemiHdpState.ClusterState

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
uni_ls_stateUniLSState

multi_ls_stateMultiLSState

lin_reg_uni_ls_stateLinRegUniLSState

cardinalityint32

+ +

SemiHdpState.GroupState

+

+ + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
cluster_allocsint32repeated

+ +

+ SemiHdpState.RestaurantState +

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldTypeLabelDescription
theta_stars + SemiHdpState.ClusterState + repeated

n_by_clusint32repeated

table_to_sharedint32repeated

table_to_idioint32repeated

Scalar Value Types

- + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
.proto TypeNotesC++JavaPythonGoC#PHPRuby
.proto TypeNotesC++JavaPythonGoC#PHPRuby
doubledoubledoublefloatfloat64doublefloatFloat
floatfloatfloatfloatfloat32floatfloatFloat
int32Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint32 instead.int32intintint32intintegerBignum or Fixnum (as required)
int64Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint64 instead.int64longint/longint64longinteger/stringBignum
uint32Uses variable-length encoding.uint32intint/longuint32uintintegerBignum or Fixnum (as required)
uint64Uses variable-length encoding.uint64longint/longuint64ulonginteger/stringBignum or Fixnum (as required)
sint32Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int32s.int32intintint32intintegerBignum or Fixnum (as required)
sint64Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int64s.int64longint/longint64longinteger/stringBignum
fixed32Always four bytes. More efficient than uint32 if values are often greater than 2^28.uint32intintuint32uintintegerBignum or Fixnum (as required)
fixed64Always eight bytes. More efficient than uint64 if values are often greater than 2^56.uint64longint/longuint64ulonginteger/stringBignum
sfixed32Always four bytes.int32intintint32intintegerBignum or Fixnum (as required)
sfixed64Always eight bytes.int64longint/longint64longinteger/stringBignum
boolboolbooleanbooleanboolboolbooleanTrueClass/FalseClass
stringA string must always contain UTF-8 encoded or 7-bit ASCII text.stringStringstr/unicodestringstringstringString (UTF-8)
bytesMay contain any arbitrary sequence of bytes.stringByteStringstr[]byteByteStringstringString (ASCII-8BIT)
doubledoubledoublefloatfloat64doublefloatFloat
floatfloatfloatfloatfloat32floatfloatFloat
int32 + Uses variable-length encoding. Inefficient for encoding negative + numbers – if your field is likely to have negative values, use + sint32 instead. + int32intintint32intintegerBignum or Fixnum (as required)
int64 + Uses variable-length encoding. Inefficient for encoding negative + numbers – if your field is likely to have negative values, use + sint64 instead. + int64longint/longint64longinteger/stringBignum
uint32Uses variable-length encoding.uint32intint/longuint32uintintegerBignum or Fixnum (as required)
uint64Uses variable-length encoding.uint64longint/longuint64ulonginteger/stringBignum or Fixnum (as required)
sint32 + Uses variable-length encoding. Signed int value. These more + efficiently encode negative numbers than regular int32s. + int32intintint32intintegerBignum or Fixnum (as required)
sint64 + Uses variable-length encoding. Signed int value. These more + efficiently encode negative numbers than regular int64s. + int64longint/longint64longinteger/stringBignum
fixed32 + Always four bytes. More efficient than uint32 if values are often + greater than 2^28. + uint32intintuint32uintintegerBignum or Fixnum (as required)
fixed64 + Always eight bytes. More efficient than uint64 if values are often + greater than 2^56. + uint64longint/longuint64ulonginteger/stringBignum
sfixed32Always four bytes.int32intintint32intintegerBignum or Fixnum (as required)
sfixed64Always eight bytes.int64longint/longint64longinteger/stringBignum
boolboolbooleanbooleanboolboolbooleanTrueClass/FalseClass
string + A string must always contain UTF-8 encoded or 7-bit ASCII text. + stringStringstr/unicodestringstringstringString (UTF-8)
bytesMay contain any arbitrary sequence of bytes.stringByteStringstr[]byteByteStringstringString (ASCII-8BIT)
- From 6934cdcaab8f1d3fd11a55a45a460496ea6e510f Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 16 May 2022 16:54:34 +0200 Subject: [PATCH 309/317] more submodules --- docs/prior_models.rst | 83 +++++++++++++++++++++++++++++++++++++++++++ docs/states.rst | 38 ++++++++++++++++++++ docs/updaters.rst | 77 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 198 insertions(+) create mode 100644 docs/prior_models.rst create mode 100644 docs/states.rst create mode 100644 docs/updaters.rst diff --git a/docs/prior_models.rst b/docs/prior_models.rst new file mode 100644 index 000000000..8df7cc049 --- /dev/null +++ b/docs/prior_models.rst @@ -0,0 +1,83 @@ +bayesmix/hierarchies/prior_models + +Prior Models +============ + +A ``PriorModel`` represents the prior for the parameters in the likelihood, i.e. + +.. math:: + \bm{\tau} \sim G_{0} + +with :math:`G_{0}` being a suitable prior on the parameters space. We also allow for more flexible priors adding further level of randomness (i.e. the hyperprior) on the parameter characterizing :math:`G_{0}` + +------------------------- +Main operations performed +------------------------- + +A likelihood object must be able to perform the following operations: + +a. First of all, ``lpdf()`` and ``lpdf_from_unconstrained()`` methods evaluate the log-prior density function at the current state :math:`\bm \tau` or its unconstrained representation. +In particular, ``lpdf_from_unconstrained()`` is needed by Metropolis-like updaters. + +b. The ``sample()`` method generates a draw from the prior distribution. If ``hier_hypers`` is ``nullptr, the prior hyperparameter values are used. +To allow sampling from the full conditional distribution in case of semi-congugate hierarchies, we introduce the ``hier_hypers`` parameter, which is a pointer to a ``Protobuf`` message storing the hierarchy hyperaprameters to use for the sampling. + +c. The ``update_hypers()`` method updates the prior hyperparameters, given the vector of all cluster states. + + +-------------- +Code structure +-------------- + +As for the ``Likelihood`` classes we employ the Curiously Recurring Template Pattern to manage the polymorphic nature of ``PriorModel`` classes. + +The class ``AbstractPriorModel`` defines the API, i.e. all the methods that need to be called from outside of a ``PrioModel`` class. +A template class ``BasePriorModel`` inherits from ``AbstractPriorModel`` and implements some of the necessary virtual methods, which need not be implemented by the child classes. + +Instead, child classes **must** implement: + +a. ``lpdf``: evaluates :math:`G_0(\theta_h)` +b. ``sample``: samples from :math:`G_0` given a hyperparameters (passed as a pointer). If ``hier_hypers`` is ``nullptr``, the prior hyperparameter values are used. +c. ``set_hypers_from_proto``: sets the hyperparameters from a ``Probuf``message +d. ``get_hypers_proto``: returns the hyperparameters as a ``Probuf``message +e. ``initialize_hypers``: provides a default initialization of hyperparameters + +In case you want to use a Metropolis-like updater, child classes **should** also implement: + +f. ``lpdf_from_unconstrained``: evaluates :math:`G_0(\tilde{\theta}_h)`, where :math:`\tilde{\theta}_h` is the vector of unconstrained parameters. + +---------------- +Abstract Classes +---------------- + +.. doxygenclass:: AbstractPriorModel + :project: bayesmix + :members: +.. doxygenclass:: BasePriorModel + :project: bayesmix + :members: + +-------------------- +Non-abstract Classes +-------------------- + +.. doxygenclass:: NIGPriorModel + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: NxIGPriorModel + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: NWPriorModel + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: MNIGPriorModel + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: FAPriorModel + :project: bayesmix + :members: + :protected-members: diff --git a/docs/states.rst b/docs/states.rst new file mode 100644 index 000000000..6193081b1 --- /dev/null +++ b/docs/states.rst @@ -0,0 +1,38 @@ +bayesmix/hierarchies/likelihoods/states + +States +====== + +``States`` are classes used to store parameters :math:`\theta_h` of every mixture component. +Their main purpose is to handle serialization and de-serialization of the state. +Moreover, they allow to go from the constrained to the unconstrained representation of the parameters (and viceversa) and compute the associated determinant of the Jacobian appearing in the change of density formula. + + +-------------- +Code Structure +-------------- + +All classes must inherit from the `BaseState` class + +.. doxygenclass:: BaseState + :project: bayesmix + :members: + +Depending on the chosen ``Updater``, the unconstrained representation might not be needed, and the methods ``get_unconstrained()``, ``set_from_unconstrained()`` and ``log_det_jac()`` might never be called. +Therefore, we do not force users to implement them. +Instead, the ``set_from_proto()`` and ``get_as_proto()`` are fundamental as they allow the interaction with Google's Protocol Buffers library. + +------------- +State Classes +------------- + +.. doxygenclass:: UniLSState + :project: bayesmix + :members: +.. doxygenclass:: MultiLSState + :project: bayesmix + :members: +.. doxygenclass:: FAState + :project: bayesmix + :members: + :protected-members: diff --git a/docs/updaters.rst b/docs/updaters.rst new file mode 100644 index 000000000..9e00718ce --- /dev/null +++ b/docs/updaters.rst @@ -0,0 +1,77 @@ +bayesmix/hierarchies/updaters + +Updaters +======== + +An ``Updater`` implements the machinery to provide a sampling from the full conditional distribution of a given hierarchy. + +The only operation performed is ``draw`` that samples from the full conditional, either exactly or via Markov chain Monte Carlo. + +.. doxygenclass:: AbstractUpdater + :project: bayesmix + :members: + +-------------- +Code Structure +-------------- + +We distinguish between semi-conjugate updaters and the metropolis-like updaters. + + +Semi Conjugate Updaters +----------------------- + +A semi-conjugate updater can be used when the full conditional distribution has the same form of the prior. Therefore, to sample from the full conditional, it is sufficient to call the ``draw`` method of the prior, but with an updated set of hyperparameters. + +The class ``SemiConjugateUpdater`` defines the API + +.. doxygenclass:: SemiConjugateUpdater + :project: bayesmix + :members: + +Classes inheriting from this one should only implement the ``compute_posterior_hypers(...)`` member function. + + +Metropolis-like Updaters +------------------------ + +A Metropolis updater uses the Metropolis-Hastings algorithm (or its variations) to sample from the full conditional density. + +.. doxygenclass:: MetropolisUpdater + :project: bayesmix + :members: + + +Classes inheriting from this one should only implement the ``sample_proposal(...)`` method, which samples from the porposal distribution, and the ``proposal_lpdf`` one, which evaluates the proposal density log-probability density function. + +--------------- +Updater Classes +--------------- + +.. doxygenclass:: RandomWalkUpdater + :project: bayesmix + :members: +.. doxygenclass:: MalaUpdater + :project: bayesmix + :members: + +.. doxygenclass:: NNIGUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: NNxIGUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: NNWUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: MNIGUpdater + :project: bayesmix + :members: + :protected-members: +.. doxygenclass:: FAUpdater + :project: bayesmix + :members: + :protected-members: From dddbda3d6406b9bdc4837d615232d7d8a6f44d5c Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 16 May 2022 17:57:09 +0200 Subject: [PATCH 310/317] latexifying stuff --- docs/algorithms.rst | 2 +- docs/prior_models.rst | 8 +++- docs/states.rst | 10 +++-- src/algorithms/conditional_algorithm.h | 20 +++++----- src/algorithms/marginal_algorithm.h | 24 ++++++------ src/hierarchies/fa_hierarchy.h | 12 +++--- src/hierarchies/lapnig_hierarchy.h | 15 +++++--- .../likelihoods/states/base_state.h | 18 ++++----- src/hierarchies/likelihoods/states/fa_state.h | 13 ++++--- src/hierarchies/nnig_hierarchy.h | 11 ++++-- src/hierarchies/nnw_hierarchy.h | 15 +++++--- src/hierarchies/priors/fa_prior_model.h | 12 +++--- src/hierarchies/priors/mnig_prior_model.h | 6 ++- src/hierarchies/priors/nig_prior_model.h | 12 +++--- src/hierarchies/priors/nw_prior_model.h | 11 ++++-- src/hierarchies/priors/nxig_prior_model.h | 6 ++- src/hierarchies/updaters/nnxig_updater.h | 9 +++-- src/mixings/dirichlet_mixing.h | 21 +++++----- src/mixings/logit_sb_mixing.h | 23 +++++------ src/mixings/mixture_finite_mixing.h | 31 ++++++++------- src/mixings/pityor_mixing.h | 22 +++++------ src/mixings/truncated_sb_mixing.h | 38 ++++++++++--------- 22 files changed, 191 insertions(+), 148 deletions(-) diff --git a/docs/algorithms.rst b/docs/algorithms.rst index 635800167..e15aff394 100644 --- a/docs/algorithms.rst +++ b/docs/algorithms.rst @@ -20,7 +20,7 @@ Algorithms .. doxygenclass:: Neal8Algorithm :project: bayesmix :members: -.. doxygenclass:: SplitMergeAlgorithm +.. doxygenclass:: SplitAndMergeAlgorithm :project: bayesmix :members: .. doxygenclass:: ConditionalAlgorithm diff --git a/docs/prior_models.rst b/docs/prior_models.rst index 8df7cc049..002ab1f85 100644 --- a/docs/prior_models.rst +++ b/docs/prior_models.rst @@ -38,8 +38,8 @@ Instead, child classes **must** implement: a. ``lpdf``: evaluates :math:`G_0(\theta_h)` b. ``sample``: samples from :math:`G_0` given a hyperparameters (passed as a pointer). If ``hier_hypers`` is ``nullptr``, the prior hyperparameter values are used. -c. ``set_hypers_from_proto``: sets the hyperparameters from a ``Probuf``message -d. ``get_hypers_proto``: returns the hyperparameters as a ``Probuf``message +c. ``set_hypers_from_proto``: sets the hyperparameters from a ``Probuf`` message +d. ``get_hypers_proto``: returns the hyperparameters as a ``Probuf`` message e. ``initialize_hypers``: provides a default initialization of hyperparameters In case you want to use a Metropolis-like updater, child classes **should** also implement: @@ -65,18 +65,22 @@ Non-abstract Classes :project: bayesmix :members: :protected-members: + .. doxygenclass:: NxIGPriorModel :project: bayesmix :members: :protected-members: + .. doxygenclass:: NWPriorModel :project: bayesmix :members: :protected-members: + .. doxygenclass:: MNIGPriorModel :project: bayesmix :members: :protected-members: + .. doxygenclass:: FAPriorModel :project: bayesmix :members: diff --git a/docs/states.rst b/docs/states.rst index 6193081b1..b20fddaac 100644 --- a/docs/states.rst +++ b/docs/states.rst @@ -14,7 +14,7 @@ Code Structure All classes must inherit from the `BaseState` class -.. doxygenclass:: BaseState +.. doxygenclass:: State::BaseState :project: bayesmix :members: @@ -26,13 +26,15 @@ Instead, the ``set_from_proto()`` and ``get_as_proto()`` are fundamental as they State Classes ------------- -.. doxygenclass:: UniLSState +.. doxygenclass:: State::UniLS :project: bayesmix :members: -.. doxygenclass:: MultiLSState + +.. doxygenclass:: State::MultiLS :project: bayesmix :members: -.. doxygenclass:: FAState + +.. doxygenclass:: State::FA :project: bayesmix :members: :protected-members: diff --git a/src/algorithms/conditional_algorithm.h b/src/algorithms/conditional_algorithm.h index 988780c96..d83d5f7af 100644 --- a/src/algorithms/conditional_algorithm.h +++ b/src/algorithms/conditional_algorithm.h @@ -12,15 +12,17 @@ //! This template class implements a generic Gibbs sampling conditional //! algorithm as the child of the `BaseAlgorithm` class. //! A mixture model sampled from a conditional algorithm can be expressed as -//! x_i | c_i, phi_1, ..., phi_k ~ f(x_i|phi_(c_i)) (data likelihood); -//! phi_1, ... phi_k ~ G (unique values); -//! c_1, ... c_n | w_1, ..., w_k ~ Cat(w_1, ... w_k) (cluster allocations); -//! w_1, ..., w_k ~ p(w_1, ..., w_k) (mixture weights) -//! where f(x | phi_j) is a density for each value of phi_j, the c_i take -//! values in {1, ..., k} and w_1, ..., w_k are nonnegative weights whose sum -//! is a.s. 1, i.e. p(w_1, ... w_k) is a probability distribution on the k-1 -//! dimensional unit simplex). -//! In this library, each phi_j is represented as an `Hierarchy` object (which +//! \f[ +//! x_i | c_i, \theta_1, ..., \theta_k & \sim f(x_i|\theta_(c_i)) \\ +//! \theta_1, ..., \theta_k & \sim G_0 \\ +//! c_1, ... c_n | w_1, ..., w_k & \sim \text{Cat}(w_1, ... w_k) \\ +//! w_1, ..., w_k & \sim p(w_1, ..., w_k) +//! \f] +//! where \f$f(x | \theta_j)\f$ is a density for each value of \f$\theta_j\f$, +//! \f$c_i\f$ take values in \f${1, ..., k}\f$ and \f$w_1, ..., w_k\f$ are +//! nonnegative weights whose sum is a.s. 1, i.e. \f$p(w_1, ... w_k)\f$ is a +//! probability distribution on the k-1 dimensional unit simplex). n this +//! library, each \f$\theta_j\f$ is represented as an `Hierarchy` object (which //! inherits from `AbstractHierarchy`), that also holds the information related //! to the base measure `G` is (see `AbstractHierarchy`). //! The weights (w_1, ..., w_k) are represented as a `Mixing` object, which diff --git a/src/algorithms/marginal_algorithm.h b/src/algorithms/marginal_algorithm.h index 34af2167a..d0c880e10 100644 --- a/src/algorithms/marginal_algorithm.h +++ b/src/algorithms/marginal_algorithm.h @@ -13,17 +13,19 @@ //! This template class implements a generic Gibbs sampling marginal algorithm //! as the child of the `BaseAlgorithm` class. //! A mixture model sampled from a Marginal Algorithm can be expressed as -//! x_i | c_i, phi_1, ..., phi_k ~ f(x_i|phi_(c_i)) (data likelihood); -//! phi_1, ... phi_k ~ G (unique values); -//! c_1, ... c_n ~ EPPF(c_1, ... c_n) (cluster allocations); -//! where f(x | phi_j) is a density for each value of phi_j and the c_i take -//! values in {1, ..., k}. -//! Depending on the actual implementation, the algorithm might require -//! the kernel/likelihood f(x | phi) and G(phi) to be conjugagte or not. -//! In the former case, a `ConjugateHierarchy` must be specified. -//! In this library, each phi_j is represented as an `Hierarchy` object (which -//! inherits from `AbstractHierarchy`), that also holds the information related -//! to the base measure `G` is (see `AbstractHierarchy`). The EPPF is instead +//! \f[ +//! x_i | c_i, \theta_1, ..., \theta_k & \sim f(x_i|\theta_(c_i)) \\ +//! \theta_1, ..., \theta_k & \sim G_0 \\ +//! c_1, ... c_n & \sim EPPF(c_1, ... c_n) +//! \f] +//! where \f$f(x | \theta_j)\f$ is a density for each value of \f$\theta_j\f$ +//! and \f$c_i\f$ take values in \f${1, ..., k}\f$. Depending on the actual +//! implementation, the algorithm might require the kernel/likelihood \f$f(x | +//! \theta)\f$ and \f$G_0(phi)\f$ to be conjugagte or not. In the former case, +//! a conjugate hierarchy must be specified. In this library, each +//! \f$\theta_j\f$ is represented as an `Hierarchy` object (which inherits from +//! `AbstractHierarchy`), that also holds the information related to the base +//! measure \f$G_0\f$ is (see `AbstractHierarchy`). The EPPF is instead //! represented as a `Mixing` object, which inherits from `AbstractMixing`. //! //! The state of a marginal algorithm only consists of allocations and unique diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index f98dbc7c5..c0f8ecd63 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -15,11 +15,13 @@ //! of the covariance matrix (see the `FAHierarchy` class for details). The //! likelihood parameters have a Dirichlet-Laplace distribution x InverseGamma //! centering distribution (see the `FAPriorModel` class for details). That is: -//! f(x_i|mu,Sigma,Lambda) = N(mu,Sigma+Lambda*Lambda^T) -//! mu ~ N(mu0,psi*I) -//! Lambda ~ DL(alpha) -//! Sigma = diag(sig1^2,...,sigp^2) -//! sigj^2 ~ IG(a,b) for j=1,...,p +//! \f[ +//! f(x_i| \mu, \Sigma, \Lambda) &= N(\mu, \Sigma + \Lambda \Lambda^T) \\ +//! \mu &\sim N_p(\tilde \mu, \psi I) \\ +//! \Lambda &\sim DL(\alpha) \\ +//! \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ +//! \sigma^2_j &\sim IG(a,b) \quad j=1,...,p +//! \f] //! where Lambda is the latent score matrix (size p x d with d << p) and //! DL(alpha) is the Laplace-Dirichlet distribution. //! See Bhattacharya et al. (2015) for further details. diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/lapnig_hierarchy.h index 000b60bd5..d4803a5c8 100644 --- a/src/hierarchies/lapnig_hierarchy.h +++ b/src/hierarchies/lapnig_hierarchy.h @@ -13,13 +13,16 @@ //! according to a laplace likelihood (see the `LaplaceLikelihood` class for //! deatils).The likelihood parameters have a Normal x InverseGamma centering //! distribution (see the `NxIGPriorModel` class for details). That is: -//! f(x_i|mu,lambda) = Laplace(mu,sqrt(var/2)) -//! mu ~ N(mu0,sig0^2) -//! var ~ IG(alpha0,beta0) +//! \f[ +//! f(x_i|\mu,\sigma^2) &= Laplace(\mu,\sqrt(\sigma^2/2))\\ +//! \mu &\sim N(\mu_0,\eta^2) \\ +//! \sigma^2 ~ IG(a, b) +//! \f] //! The state is composed of mean and variance (thus the scale for the Laplace -//! distribution is sqrt(var / 2)). The state hyperparameters are (mu_0, -//! sig0^2, alpha0, beta0), all scalar values. Note that this hierarchy is NOT -//! conjugate, thus the marginal distribution is not available in closed form. +//! distribution is \f$ \sqrt(\sigma^2 / 2)) \f$. The state hyperparameters are +//! \f$(mu_0, \sigma^2, a, b)\f$, all scalar values. Note that this hierarchy +//! is NOT conjugate, thus the marginal distribution is not available in closed +//! form. class LapNIGHierarchy : public BaseHierarchy { diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index cb0744531..4217913d1 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -14,14 +14,17 @@ //! are distributed according to a multivariate normal likelihood (see the //! `MultiNormLikelihood` for details). The likelihood parameters have a //! Normal-Wishart centering distribution (see the `NWPriorModel` class for -//! details). That is: f(x_i|mu,tau) = N(mu,tau^{-1}) -//! (mu,tau) ~ NW(mu0, lambda0, tau0, nu0) +//! details). That is: +//! \f[ +//! f(x_i|\mu,\Sigma) &= N(\mu,\Sigma^{-1}) \\ +//! (\mu,\Sigma) &\sim NW(\mu_0, \lambda, \Psi_0, \nu_0) +//! \f] //! The state is composed of mean and precision matrix. The Cholesky factor and //! log-determinant of the latter are also included in the container for -//! efficiency reasons. The state's hyperparameters are (mu0, lambda0, tau0, -//! nu0), which are respectively vector, scalar, matrix, and scalar. Note that -//! this hierarchy is conjugate, thus the marginal distribution is available in -//! closed form. +//! efficiency reasons. The state's hyperparameters are \f$(\mu_0, \lambda, +//! \Psi_0, \nu_0)\f$, which are respectively vector, scalar, matrix, and +//! scalar. Note that this hierarchy is conjugate, thus the marginal +//! distribution is available in closed form. class NNWHierarchy : public BaseHierarchy { diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index f35ac03fa..d1689d74f 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -10,11 +10,13 @@ #include "hyperparams.h" #include "src/utils/rng.h" -//! A prior model for the factor analyzers likelihood, that is -//! mu ~ N_p(mutilde, psi*I) -//! Lambda ~ DL(alpha) -//! Sigma = diag(sigsq_1,...,sigsq_p) -//! sigsq_j ~ IG(a,b) j=1,...,p +//! A priormodel for the factor analyzers likelihood, that is +//! \f[ +//! \mu &\sim N_p(\tilde \mu, \psi I) \\ +//! \Lambda &\sim DL(\alpha) \\ +//! \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ +//! \sigma^2_j &\sim IG(a,b) \quad j=1,...,p +//! \f] //! Where DL is the Dirichlet-Laplace distribution. See Bhattacharya A., Pati //! D, Pillai N.S., Dunson D.B. (2015). JASA 110(512), 1479–1490 for details. diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h index 885993db9..1202ab551 100644 --- a/src/hierarchies/priors/mnig_prior_model.h +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -11,8 +11,10 @@ #include "src/utils/rng.h" //! A conjugate prior model for the scalar linear regression likelihood, i.e. -//! reg_coeffs | var ~ N_p(mu, var * Lambda^-1) -//! var ~ IG(a,b) +//! \f[ +//! \beta | \sigma^2 & \sim N_p(\mu, \sigma^2 \Lambda^-1) \\ +//! \sigma^2 & \sim IG(a,b) +//! \f] class MNIGPriorModel : public BasePriorModel { public: diff --git a/src/mixings/logit_sb_mixing.h b/src/mixings/logit_sb_mixing.h index 228f4da40..51fd33fff 100644 --- a/src/mixings/logit_sb_mixing.h +++ b/src/mixings/logit_sb_mixing.h @@ -12,15 +12,22 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" +namespace LogitSB { +struct State { + Eigen::MatrixXd regression_coeffs, precision; +}; +}; // namespace LogitSB + //! Class that represents the logit stick-breaking process indroduced in Rigon //! and Durante (2020). //! That is, a prior for weights (w_1,...,w_H), depending on covariates x in //! R^p, in the H-1 dimensional unit simplex, defined as follows: -//! w_1(x) = v_1(x) -//! w_j(x) = v_j(x) (1 - v_1(x)) ... (1 - v_{j-1}(x)), for j=2, ... H-1 -//! w_H(x) = 1 - (w_1(x) + w_2 + ... + w_{H-1}(x)) -//! and -//! v_j(x) = 1 / exp(- ), for j = 1, ..., H-1 +//! \f[ +//! w_1 &= v_1\\ +//! w_j &= v_j (1 - v_1) ... (1 - v_{j-1}), \quad \text{for } j=1, ... H-1 \\ +//! w_H &= 1 - (w_1 + w_2 + ... + w_{H-1}) \\ +//! v_j(x) &= 1 / exp(- <\alpha_j, x> ), for j = 1, ..., H-1 +//! \f] //! //! The main difference with the mentioned paper is that the authors propose a //! Gibbs sampler in which the full conditionals are available in close form @@ -30,12 +37,6 @@ //! For more information about the class, please refer instead to base classes, //! `AbstractMixing` and `BaseMixing`. -namespace LogitSB { -struct State { - Eigen::MatrixXd regression_coeffs, precision; -}; -}; // namespace LogitSB - class LogitSBMixing : public BaseMixing { public: diff --git a/src/mixings/mixture_finite_mixing.h b/src/mixings/mixture_finite_mixing.h index c1adc3dfa..0a38752dd 100644 --- a/src/mixings/mixture_finite_mixing.h +++ b/src/mixings/mixture_finite_mixing.h @@ -12,18 +12,27 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" +namespace Mixture_Finite { +struct State { + double lambda, gamma; +}; +}; // namespace Mixture_Finite + //! Class that represents the Mixture of Finite Mixtures (MFM) [1] //! The basic idea is to take usual finite mixture model with Dirichlet weights //! and put a prior (Poisson) on the number of components. The EPPF induced by -//! MFM depends on a Dirichlet parameter 'gamma' and number V_n(t), where -//! V_n(t) depends on the Poisson rate parameter 'lambda'. -//! V_n(t) = sum_{k=1}^{inf} ( k_(t)*p_K(k) / (gamma*k)^(n) ) +//! MFM depends on a Dirichlet parameter 'gamma' and number \f$V_n(t)\f$, where +//! \f$V_n(t)\f$ depends on the Poisson rate parameter 'lambda'. +//! \f[ +//! V_n(t) = \sum_{k=1}^{\infty} ( k_(t)p_K(k) / (\gamma*k)^(n) ) +//! \f] //! Given a clustering of n elements into k clusters, each with cardinality -//! n_j, j=1, ..., k, the EPPF of the MFM gives the following probabilities for -//! the cluster membership of the (n+1)-th observation: -//! denominator = n_j + gamma / (n + gamma*(n_clust + V[n_clust+1]/V[n_clust])) -//! p(j-th cluster | ...) = (n_j + gamma) / denominator -//! p(k+1-th cluster | ...) = V[n_clust+1]/V[n_clust]*gamma / denominator +//! \f$n_j, j=1, ..., k\f$, the EPPF of the MFM gives the following +//! probabilities for the cluster membership of the (n+1)-th observation: \f[ +//! p(\text{j-th cluster} | ...) &= (n_j + \gamma) / D \\ +//! p(\text{k+1-th cluster} | ...) &= V[k+1]/V[k] \gamma / D \\ +//! D &= n_j + \gamma / (n + \gamma * (k + V[k+1]/V[k])) +//! \f] //! For numerical reasons each value of V is multiplied with a constant C //! computed as the first term of the series of V_n[0]. //! For more information about the class, please refer instead to base @@ -31,12 +40,6 @@ //! [1] "Mixture Models with a Prior on the Number of Components", J.W.Miller //! and M.T.Harrison, 2015, arXiv:1502.06241v1 -namespace Mixture_Finite { -struct State { - double lambda, gamma; -}; -}; // namespace Mixture_Finite - class MixtureFiniteMixing : public BaseMixing { diff --git a/src/mixings/pityor_mixing.h b/src/mixings/pityor_mixing.h index 5ebec6958..6dbedf569 100644 --- a/src/mixings/pityor_mixing.h +++ b/src/mixings/pityor_mixing.h @@ -12,26 +12,26 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" +namespace PitYor { +struct State { + double strength, discount; +}; +}; // namespace PitYor + //! Class that represents the Pitman-Yor process (PY) in Pitman and Yor (1997). //! The EPPF induced by the PY depends on a `strength` parameter M and a //! `discount` paramter d. //! Given a clustering of n elements into k clusters, each with cardinality -//! n_j, j=1, ..., k, the EPPF of the PY gives the following probabilities for -//! the cluster membership of the (n+1)-th observation: -//! p(j-th cluster | ...) \propto (n_j - d) -//! p(k+1-th cluster | ...) \propto M + k * d -//! +//! \f$ n_j, j=1, ..., k \f$, the EPPF of the PY gives the following +//! probabilities for the cluster membership of the (n+1)-th observation: \f[ +//! p(\text{j-th cluster} | ...) \propto (n_j - d) \\ +//! p(\text{new cluster} | ...) \propto M + k d +//! \f] //! When `discount=0`, the EPPF of the PY process coincides with the one of the //! DP with totalmass = strength. //! For more information about the class, please refer instead to base classes, //! `AbstractMixing` and `BaseMixing`. -namespace PitYor { -struct State { - double strength, discount; -}; -}; // namespace PitYor - class PitYorMixing : public BaseMixing { public: diff --git a/src/mixings/truncated_sb_mixing.h b/src/mixings/truncated_sb_mixing.h index 120a6cc20..3af68249b 100644 --- a/src/mixings/truncated_sb_mixing.h +++ b/src/mixings/truncated_sb_mixing.h @@ -12,30 +12,32 @@ #include "mixing_prior.pb.h" #include "src/hierarchies/abstract_hierarchy.h" -//! Class that represents a truncated stick-breaking process, as shown in -//! Ishwaran and James (2001). -//! -//! A truncated stick-breaking process is a prior for weights (w_1,...,w_H) in -//! the H-1 dimensional unit simplex, and is defined as follows: -//! w_1 = v_1 -//! w_j = v_j (1 - v_1) ... (1 - v_{j-1}), for j=1, ... H-1 -//! w_H = 1 - (w_1 + w_2 + ... + w_{H-1}) -//! The v_j's are called sticks and we assume them to be independently -//! distributed as v_j ~ Beta(a_j, b_j). -//! -//! When a_j = 1 and b_j = M, the stick-breaking process is a truncation of the -//! stick-breaking representation of the DP. -//! When a_j = 1 - d and b_j = M + i*d, it is the trunctation of a PY process. -//! Its state is composed of the weights w_j in log-scale and the sticks v_j. -//! For more information about the class, please refer instead to base classes, -//! `AbstractMixing` and `BaseMixing`. - namespace TruncSB { struct State { Eigen::VectorXd sticks, logweights; }; }; // namespace TruncSB +//! Class that represents a truncated stick-breaking process, as shown in +//! Ishwaran and James (2001). +//! +//! A truncated stick-breaking process is a prior for weights +//! \f$(w_1,...,w_H)\f$ in the H-1 dimensional unit simplex, and is defined as +//! follows: \f[ +//! w_1 &= v_1\\ +//! w_j &= v_j (1 - v_1) ... (1 - v_{j-1}), \quad \text{for } j=1, ... H-1 \\ +//! w_H &= 1 - (w_1 + w_2 + ... + w_{H-1}) +//! \f] +//! The \f$v_j\f$'s are called sticks and we assume them to be independently +//! distributed as \f$v_j \sim \text{Beta}(a_j, b_j)\f$. +//! +//! When \f$a_j = 1\f$ and \f$b_j = M\f$, the stick-breaking process is a +//! truncation of the stick-breaking representation of the DP. When \f$a_j = 1 +//! - d\f$ and \f$b_j = M + id \f$, it is the trunctation of a PY process. Its +//! state is composed of the weights \f$w_j\f$ in log-scale and the sticks +//! \f$v_j\f$. For more information about the class, please refer instead to +//! base classes, `AbstractMixing` and `BaseMixing`. + class TruncatedSBMixing : public BaseMixing { public: From 86b0c2cf562e2fd3b14cc85a4f10b946223e5fea Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 16 May 2022 18:02:18 +0200 Subject: [PATCH 311/317] addressing bruno's comments --- src/hierarchies/base_hierarchy.h | 8 ++++---- src/hierarchies/lapnig_hierarchy.h | 4 ++-- src/hierarchies/lin_reg_uni_hierarchy.h | 4 ++-- src/hierarchies/nnig_hierarchy.h | 4 ++-- src/hierarchies/nnw_hierarchy.h | 8 ++++---- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/hierarchies/base_hierarchy.h b/src/hierarchies/base_hierarchy.h index ad71a0c93..86ea08a0a 100644 --- a/src/hierarchies/base_hierarchy.h +++ b/src/hierarchies/base_hierarchy.h @@ -343,8 +343,8 @@ class BaseHierarchy : public AbstractHierarchy { virtual void initialize_state() = 0; //! Evaluates the log-marginal distribution of data in a single point - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container (prior or posterior) + //! hyperparameter values //! @param datum Point which is to be evaluated //! @return The evaluation of the lpdf virtual double marg_lpdf(ProtoHypersPtr hier_params, @@ -359,8 +359,8 @@ class BaseHierarchy : public AbstractHierarchy { } //! Evaluates the log-marginal distribution of data in a single point - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values //! @param datum Point which is to be evaluated //! @param covariate Covariate vector associated to datum //! @return The evaluation of the lpdf diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/lapnig_hierarchy.h index d4803a5c8..92b66fdd9 100644 --- a/src/hierarchies/lapnig_hierarchy.h +++ b/src/hierarchies/lapnig_hierarchy.h @@ -10,8 +10,8 @@ //! Laplace Normal-InverseGamma hierarchy for univariate data. //! This class represents a hierarchical model where data are distributed -//! according to a laplace likelihood (see the `LaplaceLikelihood` class for -//! deatils).The likelihood parameters have a Normal x InverseGamma centering +//! according to a Laplace likelihood (see the `LaplaceLikelihood` class for +//! deatils). The likelihood parameters have a Normal x InverseGamma centering //! distribution (see the `NxIGPriorModel` class for details). That is: //! \f[ //! f(x_i|\mu,\sigma^2) &= Laplace(\mu,\sqrt(\sigma^2/2))\\ diff --git a/src/hierarchies/lin_reg_uni_hierarchy.h b/src/hierarchies/lin_reg_uni_hierarchy.h index bd208ef01..7afd951f4 100644 --- a/src/hierarchies/lin_reg_uni_hierarchy.h +++ b/src/hierarchies/lin_reg_uni_hierarchy.h @@ -56,8 +56,8 @@ class LinRegUniHierarchy }; //! Evaluates the log-marginal distribution of data in a single point - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values //! @param datum Point which is to be evaluated //! @param covariate Covariate vectors associated to data //! @return The evaluation of the lpdf diff --git a/src/hierarchies/nnig_hierarchy.h b/src/hierarchies/nnig_hierarchy.h index cb6a0a545..161882ae1 100644 --- a/src/hierarchies/nnig_hierarchy.h +++ b/src/hierarchies/nnig_hierarchy.h @@ -48,8 +48,8 @@ class NNIGHierarchy }; //! Evaluates the log-marginal distribution of data in a single point - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values //! @param datum Point which is to be evaluated //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index 4217913d1..6dfbebb0e 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -56,8 +56,8 @@ class NNWHierarchy }; //! Evaluates the log-marginal distribution of data in a single point - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values //! @param datum Point which is to be evaluated //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, @@ -73,8 +73,8 @@ class NNWHierarchy //! Helper function that computes the predictive parameters for the //! multivariate t distribution from the current hyperparameter values. It is //! used to efficiently compute the log-marginal distribution of data. - //! @param hier_params Container of (prior or posterior) hyperparameter - //! values + //! @param hier_params Pointer to the container of (prior or posterior) + //! hyperparameter values //! @return A `HyperParam` object with the predictive parameters HyperParams get_predictive_t_parameters(ProtoHypersPtr hier_params) const { auto params = hier_params->nnw_state(); From 55b6c0cf0737f90dafd1555e7867b7742f118ffc Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Mon, 16 May 2022 18:03:47 +0200 Subject: [PATCH 312/317] rebuilding protos --- docs/CMakeLists.txt | 2 +- src/hierarchies/priors/fa_prior_model.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/CMakeLists.txt b/docs/CMakeLists.txt index 3318965a8..5981d2b7a 100644 --- a/docs/CMakeLists.txt +++ b/docs/CMakeLists.txt @@ -82,6 +82,6 @@ install(DIRECTORY ${SPHINX_BUILD} DESTINATION ${CMAKE_INSTALL_DOCDIR}) add_custom_target(document_bayesmix) -# add_dependencies(document_bayesmix document_protos) +add_dependencies(document_bayesmix document_protos) add_dependencies(document_bayesmix Doxygen) add_dependencies(document_bayesmix Sphinx) diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index d1689d74f..dfec2a731 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -15,7 +15,7 @@ //! \mu &\sim N_p(\tilde \mu, \psi I) \\ //! \Lambda &\sim DL(\alpha) \\ //! \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ -//! \sigma^2_j &\sim IG(a,b) \quad j=1,...,p +//! \sigma^2_j &\sim IG(a,b) \quadq j=1,...,p //! \f] //! Where DL is the Dirichlet-Laplace distribution. See Bhattacharya A., Pati //! D, Pillai N.S., Dunson D.B. (2015). JASA 110(512), 1479–1490 for details. From 1f35fbbd9bcf984982550e6c8f735d082eed5f12 Mon Sep 17 00:00:00 2001 From: Mario beraha Date: Tue, 17 May 2022 08:54:01 +0200 Subject: [PATCH 313/317] typo --- src/hierarchies/priors/fa_prior_model.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index dfec2a731..d1689d74f 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -15,7 +15,7 @@ //! \mu &\sim N_p(\tilde \mu, \psi I) \\ //! \Lambda &\sim DL(\alpha) \\ //! \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ -//! \sigma^2_j &\sim IG(a,b) \quadq j=1,...,p +//! \sigma^2_j &\sim IG(a,b) \quad j=1,...,p //! \f] //! Where DL is the Dirichlet-Laplace distribution. See Bhattacharya A., Pati //! D, Pillai N.S., Dunson D.B. (2015). JASA 110(512), 1479–1490 for details. From 4d7d598b4294ae68494d6804c25a14b8b9cd7681 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 17 May 2022 22:22:16 +0200 Subject: [PATCH 314/317] latex hotfix --- src/algorithms/conditional_algorithm.h | 52 +++++++++-------- src/algorithms/marginal_algorithm.h | 56 ++++++++++--------- src/hierarchies/fa_hierarchy.h | 38 +++++++------ src/hierarchies/lapnig_hierarchy.h | 35 ++++++------ src/hierarchies/lin_reg_uni_hierarchy.h | 18 +++--- src/hierarchies/nnig_hierarchy.h | 32 ++++++----- src/hierarchies/nnw_hierarchy.h | 37 ++++++------ src/hierarchies/nnxig_hierarchy.h | 32 ++++++----- src/hierarchies/priors/fa_prior_model.h | 6 +- src/hierarchies/priors/mnig_prior_model.h | 13 +++-- src/hierarchies/priors/nig_prior_model.h | 21 ++++--- src/hierarchies/priors/nw_prior_model.h | 20 ++++--- src/hierarchies/priors/nxig_prior_model.h | 13 +++-- src/hierarchies/updaters/mala_updater.h | 29 ++++++---- src/hierarchies/updaters/mnig_updater.h | 25 ++++++--- src/hierarchies/updaters/nnig_updater.h | 23 +++++--- src/hierarchies/updaters/nnw_updater.h | 23 +++++--- src/hierarchies/updaters/nnxig_updater.h | 24 +++++--- .../updaters/random_walk_updater.h | 29 ++++++---- 19 files changed, 307 insertions(+), 219 deletions(-) diff --git a/src/algorithms/conditional_algorithm.h b/src/algorithms/conditional_algorithm.h index d83d5f7af..a87fb6630 100644 --- a/src/algorithms/conditional_algorithm.h +++ b/src/algorithms/conditional_algorithm.h @@ -7,30 +7,34 @@ #include "base_algorithm.h" #include "src/collectors/base_collector.h" -//! Template class for a conditional sampler deriving from `BaseAlgorithm`. - -//! This template class implements a generic Gibbs sampling conditional -//! algorithm as the child of the `BaseAlgorithm` class. -//! A mixture model sampled from a conditional algorithm can be expressed as -//! \f[ -//! x_i | c_i, \theta_1, ..., \theta_k & \sim f(x_i|\theta_(c_i)) \\ -//! \theta_1, ..., \theta_k & \sim G_0 \\ -//! c_1, ... c_n | w_1, ..., w_k & \sim \text{Cat}(w_1, ... w_k) \\ -//! w_1, ..., w_k & \sim p(w_1, ..., w_k) -//! \f] -//! where \f$f(x | \theta_j)\f$ is a density for each value of \f$\theta_j\f$, -//! \f$c_i\f$ take values in \f${1, ..., k}\f$ and \f$w_1, ..., w_k\f$ are -//! nonnegative weights whose sum is a.s. 1, i.e. \f$p(w_1, ... w_k)\f$ is a -//! probability distribution on the k-1 dimensional unit simplex). n this -//! library, each \f$\theta_j\f$ is represented as an `Hierarchy` object (which -//! inherits from `AbstractHierarchy`), that also holds the information related -//! to the base measure `G` is (see `AbstractHierarchy`). -//! The weights (w_1, ..., w_k) are represented as a `Mixing` object, which -//! inherits from `AbstractMixing`. - -//! The state of a conditional algorithm consists of the unique values, the -//! cluster allocations and the mixture weights. The former two are stored in -//! this class, while the weights are stored in the `Mixing` object. +/** + * Template class for a conditional sampler deriving from `BaseAlgorithm`. + * + * This template class implements a generic Gibbs sampling conditional + * algorithm as the child of the `BaseAlgorithm` class. + * A mixture model sampled from a conditional algorithm can be expressed as + * + * \f[ + * x_i \mid c_i, \theta_1, \dots, \theta_k &\sim f(x_i \mid \theta_{c_i}) \\ + * \theta_1, \dots, \theta_k &\sim G_0 \\ + * c_1, \dots, c_n \mid w_1, \dots, w_k &\sim \text{Cat}(w_1, \dots, w_k) \\ + * w_1, \dots, w_k &\sim p(w_1, \dots, w_k) + * \f] + * + * where \f$ f(x \mid \theta_j) \f$ is a density for each value of \f$ \theta_j + * \f$, \f$ c_i \f$ take values in \f$ \{1, \dots, k\} \f$ and \f$ w_1, \dots, + * w_k \f$ are nonnegative weights whose sum is a.s. 1, i.e. \f$ p(w_1, ... + * w_k) \f$ is a probability distribution on the k-1 dimensional unit simplex). + * In this library, each \f$ \theta_j \f$ is represented as an `Hierarchy` + * object (which inherits from `AbstractHierarchy`), that also holds the + * information related to the base measure \f$ G \f$ is (see + * `AbstractHierarchy`). The weights \f$ (w_1, \dots, w_k) \f$ are represented + * as a `Mixing` object, which inherits from `AbstractMixing`. + * + * The state of a conditional algorithm consists of the unique values, the + * cluster allocations and the mixture weights. The former two are stored in + * this class, while the weights are stored in the `Mixing` object. + */ class ConditionalAlgorithm : public BaseAlgorithm { public: diff --git a/src/algorithms/marginal_algorithm.h b/src/algorithms/marginal_algorithm.h index d0c880e10..95fcddd92 100644 --- a/src/algorithms/marginal_algorithm.h +++ b/src/algorithms/marginal_algorithm.h @@ -8,31 +8,37 @@ #include "src/collectors/base_collector.h" #include "src/hierarchies/abstract_hierarchy.h" -//! Template class for a marginal sampler deriving from `BaseAlgorithm`. - -//! This template class implements a generic Gibbs sampling marginal algorithm -//! as the child of the `BaseAlgorithm` class. -//! A mixture model sampled from a Marginal Algorithm can be expressed as -//! \f[ -//! x_i | c_i, \theta_1, ..., \theta_k & \sim f(x_i|\theta_(c_i)) \\ -//! \theta_1, ..., \theta_k & \sim G_0 \\ -//! c_1, ... c_n & \sim EPPF(c_1, ... c_n) -//! \f] -//! where \f$f(x | \theta_j)\f$ is a density for each value of \f$\theta_j\f$ -//! and \f$c_i\f$ take values in \f${1, ..., k}\f$. Depending on the actual -//! implementation, the algorithm might require the kernel/likelihood \f$f(x | -//! \theta)\f$ and \f$G_0(phi)\f$ to be conjugagte or not. In the former case, -//! a conjugate hierarchy must be specified. In this library, each -//! \f$\theta_j\f$ is represented as an `Hierarchy` object (which inherits from -//! `AbstractHierarchy`), that also holds the information related to the base -//! measure \f$G_0\f$ is (see `AbstractHierarchy`). The EPPF is instead -//! represented as a `Mixing` object, which inherits from `AbstractMixing`. -//! -//! The state of a marginal algorithm only consists of allocations and unique -//! values. In this class of algorithms, the local lpdf estimate for a single -//! iteration is a weighted average of likelihood values corresponding to each -//! component (i.e. cluster), with the weights being based on its cardinality, -//! and of the marginal component, which depends on the specific algorithm. +/** + * Template class for a marginal sampler deriving from `BaseAlgorithm`. + * + * This template class implements a generic Gibbs sampling marginal algorithm + * as the child of the `BaseAlgorithm` class. + * A mixture model sampled from a Marginal Algorithm can be expressed as + * + * \f[ + * x_i \mid c_i, \theta_1, \dots, \theta_k &\sim f(x_i \mid \theta_{c_i}) + * \\ + * \theta_1, \dots, \theta_k &\sim G_0 \\ + * c_1, \dots, c_n &\sim EPPF(c_1, \dots, c_n) + * \f] + * + * where \f$ f(x \mid \theta_j) \f$ is a density for each value of \f$ \theta_j + * \f$ and \f$ c_i \f$ take values in \f$ {1, \dots, k} \f$. Depending on the + * actual implementation, the algorithm might require the kernel/likelihood \f$ + * f(x \mid \theta) \f$ and \f$ G_0(\phi) \f$ to be conjugagte or not. In the + * former case, a conjugate hierarchy must be specified. In this library, each + * \f$ \theta_j \f$ is represented as an `Hierarchy` object (which inherits + * from `AbstractHierarchy`), that also holds the information related to the + * base measure \f$ G_0 \f$ is (see `AbstractHierarchy`). The \f$ EPPF \f$ is + * instead represented as a `Mixing` object, which inherits from + * `AbstractMixing`. + * + * The state of a marginal algorithm only consists of allocations and unique + * values. In this class of algorithms, the local lpdf estimate for a single + * iteration is a weighted average of likelihood values corresponding to each + * component (i.e. cluster), with the weights being based on its cardinality, + * and of the marginal component, which depends on the specific algorithm. + */ class MarginalAlgorithm : public BaseAlgorithm { public: diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index c0f8ecd63..17da43564 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -8,23 +8,27 @@ #include "src/utils/distributions.h" #include "updaters/fa_updater.h" -//! Mixture of Factor Analysers hierarchy for multivariate data. -//! -//! This class represents a hierarchical model where data are distributed -//! according to a multivariate Normal likelihood with a specific factorization -//! of the covariance matrix (see the `FAHierarchy` class for details). The -//! likelihood parameters have a Dirichlet-Laplace distribution x InverseGamma -//! centering distribution (see the `FAPriorModel` class for details). That is: -//! \f[ -//! f(x_i| \mu, \Sigma, \Lambda) &= N(\mu, \Sigma + \Lambda \Lambda^T) \\ -//! \mu &\sim N_p(\tilde \mu, \psi I) \\ -//! \Lambda &\sim DL(\alpha) \\ -//! \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ -//! \sigma^2_j &\sim IG(a,b) \quad j=1,...,p -//! \f] -//! where Lambda is the latent score matrix (size p x d with d << p) and -//! DL(alpha) is the Laplace-Dirichlet distribution. -//! See Bhattacharya et al. (2015) for further details. +/** + * Mixture of Factor Analysers hierarchy for multivariate data. + * + * This class represents a hierarchical model where data are distributed + * according to a multivariate Normal likelihood with a specific factorization + * of the covariance matrix (see the `FAHierarchy` class for details). The + * likelihood parameters have a Dirichlet-Laplace distribution x InverseGamma + * centering distribution (see the `FAPriorModel` class for details). That is: + * + * \f[ + * f(x_i \mid \mu, \Sigma, \Lambda) &= N(\mu, \Sigma + \Lambda \Lambda^T) \\ + * \mu &\sim N_p(\tilde \mu, \psi I) \\ + * \Lambda &\sim DL(\alpha) \\ + * \Sigma &= diag(\sigma^2_1, \ldots, \sigma^2_p) \\ + * \sigma^2_j &\sim IG(a,b) \quad j=1,...,p + * \f] + * + * where Lambda is the latent score matrix (size \f$ p \times d \f$ + * with \f$ d << p \f$) and \f$ DL(\alpha) \f$ is the Laplace-Dirichlet + * distribution. See Bhattacharya et al. (2015) for further details + */ class FAHierarchy : public BaseHierarchy { diff --git a/src/hierarchies/lapnig_hierarchy.h b/src/hierarchies/lapnig_hierarchy.h index 92b66fdd9..01574da19 100644 --- a/src/hierarchies/lapnig_hierarchy.h +++ b/src/hierarchies/lapnig_hierarchy.h @@ -7,22 +7,25 @@ #include "priors/nxig_prior_model.h" #include "updaters/mala_updater.h" -//! Laplace Normal-InverseGamma hierarchy for univariate data. - -//! This class represents a hierarchical model where data are distributed -//! according to a Laplace likelihood (see the `LaplaceLikelihood` class for -//! deatils). The likelihood parameters have a Normal x InverseGamma centering -//! distribution (see the `NxIGPriorModel` class for details). That is: -//! \f[ -//! f(x_i|\mu,\sigma^2) &= Laplace(\mu,\sqrt(\sigma^2/2))\\ -//! \mu &\sim N(\mu_0,\eta^2) \\ -//! \sigma^2 ~ IG(a, b) -//! \f] -//! The state is composed of mean and variance (thus the scale for the Laplace -//! distribution is \f$ \sqrt(\sigma^2 / 2)) \f$. The state hyperparameters are -//! \f$(mu_0, \sigma^2, a, b)\f$, all scalar values. Note that this hierarchy -//! is NOT conjugate, thus the marginal distribution is not available in closed -//! form. +/** + * Laplace Normal-InverseGamma hierarchy for univariate data. + * + * This class represents a hierarchical model where data are distributed + * according to a Laplace likelihood (see the `LaplaceLikelihood` class for + * deatils). The likelihood parameters have a Normal x InverseGamma centering + * distribution (see the `NxIGPriorModel` class for details). That is: + * + * \f[ + * f(x_i \mid \mu,\sigma^2) &= Laplace(\mu,\sqrt{\sigma^2/2})\\ + * \mu &\sim N(\mu_0,\eta^2) \\ + * \sigma^2 &\sim InvGamma(a, b) + * \f] + * The state is composed of mean and variance (thus the scale for the Laplace + * distribution is \f$ \sqrt{\sigma^2/2}) \f$. The state hyperparameters are + * \f$(mu_0, \sigma^2, a, b)\f$, all scalar values. Note that this hierarchy + * is NOT conjugate, thus the marginal distribution is not available in closed + * form. + */ class LapNIGHierarchy : public BaseHierarchy { diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index 6dfbebb0e..7d8bd330d 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -8,23 +8,26 @@ #include "src/utils/distributions.h" #include "updaters/nnw_updater.h" -//! Normal Normal-Wishart hierarchy for multivariate data. - -//! This class represents a hierarchy whose multivariate data -//! are distributed according to a multivariate normal likelihood (see the -//! `MultiNormLikelihood` for details). The likelihood parameters have a -//! Normal-Wishart centering distribution (see the `NWPriorModel` class for -//! details). That is: -//! \f[ -//! f(x_i|\mu,\Sigma) &= N(\mu,\Sigma^{-1}) \\ -//! (\mu,\Sigma) &\sim NW(\mu_0, \lambda, \Psi_0, \nu_0) -//! \f] -//! The state is composed of mean and precision matrix. The Cholesky factor and -//! log-determinant of the latter are also included in the container for -//! efficiency reasons. The state's hyperparameters are \f$(\mu_0, \lambda, -//! \Psi_0, \nu_0)\f$, which are respectively vector, scalar, matrix, and -//! scalar. Note that this hierarchy is conjugate, thus the marginal -//! distribution is available in closed form. +/** + * Normal Normal-Wishart hierarchy for multivariate data. + * + * This class represents a hierarchy whose multivariate data + * are distributed according to a multivariate normal likelihood (see the + * `MultiNormLikelihood` for details). The likelihood parameters have a + * Normal-Wishart centering distribution (see the `NWPriorModel` class for + * details). That is: + * + * \f[ + * f(\bm{x}_i \mid \bm{\mu},\Sigma) &= N_d(\bm{\mu},\Sigma^{-1}) \\ + * (\bm{\mu},\Sigma) &\sim NW(\mu_0, \lambda, \Psi_0, \nu_0) + * \f] + * The state is composed of mean and precision matrix. The Cholesky factor and + * log-determinant of the latter are also included in the container for + * efficiency reasons. The state's hyperparameters are \f$(\mu_0, \lambda, + * \Psi_0, \nu_0)\f$, which are respectively vector, scalar, matrix, and + * scalar. Note that this hierarchy is conjugate, thus the marginal + * distribution is available in closed form + */ class NNWHierarchy : public BaseHierarchy { diff --git a/src/hierarchies/nnxig_hierarchy.h b/src/hierarchies/nnxig_hierarchy.h index 1901083c6..20aecebdd 100644 --- a/src/hierarchies/nnxig_hierarchy.h +++ b/src/hierarchies/nnxig_hierarchy.h @@ -7,19 +7,25 @@ #include "priors/nxig_prior_model.h" #include "updaters/nnxig_updater.h" -//! Semi-conjugate Normal Normal x InverseGamma hierarchy for univariate data. -//! -//! This class represents a hierarchical model where data are distributed -//! according to a Normal likelihood (see the `UniNormLikelihood` class for -//! details). The likelihood parameters have a Normal x InverseGamma centering -//! distribution (see the `NxIGPriorModel` class for details). That is: -//! f(x_i|mu,sig^2) = N(mu,sig^2) -//! mu ~ N(mu0, sig0^2) -//! sig^2 ~ IG(alpha0, beta0) -//! The state is composed of mean and variance. The state hyperparameters are -//! (mu_0, sig0^2, alpha0, beta0), all scalar values. Note that this hierarchy -//! is NOT conjugate, meaning that the marginal distribution is not available -//! in closed form. +/** + * Semi-conjugate Normal Normal x InverseGamma hierarchy for univariate data. + * + * This class represents a hierarchical model where data are distributed + * according to a Normal likelihood (see the `UniNormLikelihood` class for + * details). The likelihood parameters have a Normal x InverseGamma centering + * distribution (see the `NxIGPriorModel` class for details). That is: + * + * \f[ + * f(x_i \mid \mu,\sigma^2) &= N(\mu,\sigma^2) \\ + * \mu &\sim N(\mu_0, \eta^2) \\ + * \sigma^2 &\sim InvGamma(a, b) + * \f] + * + * The state is composed of mean and variance. The state hyperparameters are + * \f$ (\mu_0, \eta^2, a, b) \f$, all scalar values. Note that this hierarchy + * is NOT conjugate, meaning that the marginal distribution is not available + * in closed form + */ class NNxIGHierarchy : public BaseHierarchy { diff --git a/src/hierarchies/priors/fa_prior_model.h b/src/hierarchies/priors/fa_prior_model.h index d1a0297b0..9245f4885 100644 --- a/src/hierarchies/priors/fa_prior_model.h +++ b/src/hierarchies/priors/fa_prior_model.h @@ -14,11 +14,11 @@ * A priormodel for the factor analyzers likelihood, that is * * \f[ - * \mu &\sim N_p(\tilde{\mu}, \psi I) \\ + * \bm{\mu} &\sim N_p(\tilde{\bm{\mu}}, \psi I) \\ * \Lambda &\sim DL(\alpha) \\ * \Sigma &= \mathrm{diag}(\sigma^2_1, \ldots, \sigma^2_p) \\ - * \sigma^2_j &\sim IG(a,b) \quad j=1,...,p - * \f] + * \sigma^2_j &\stackrel{\small\mathrm{iid}}{\sim} InvGamma(a,b) \quad + * j=1,...,p \f] * * Where \f$ DL \f$ is the Dirichlet-Laplace distribution. * See Bhattacharya A., Pati D., Pillai N.S., Dunson D.B. (2015). diff --git a/src/hierarchies/priors/mnig_prior_model.h b/src/hierarchies/priors/mnig_prior_model.h index 1202ab551..9a2e2ddd5 100644 --- a/src/hierarchies/priors/mnig_prior_model.h +++ b/src/hierarchies/priors/mnig_prior_model.h @@ -10,11 +10,14 @@ #include "hyperparams.h" #include "src/utils/rng.h" -//! A conjugate prior model for the scalar linear regression likelihood, i.e. -//! \f[ -//! \beta | \sigma^2 & \sim N_p(\mu, \sigma^2 \Lambda^-1) \\ -//! \sigma^2 & \sim IG(a,b) -//! \f] +/** + * A conjugate prior model for the scalar linear regression likelihood, i.e. + * + * \f[ + * \bm{\beta} \mid \sigma^2 & \sim N_p(\bm{\mu}, \sigma^2 \Lambda^{-1}) \\ + * \sigma^2 & \sim InvGamma(a,b) + * \f] + */ class MNIGPriorModel : public BasePriorModel { protected: double step_size; diff --git a/src/hierarchies/updaters/mnig_updater.h b/src/hierarchies/updaters/mnig_updater.h index 17de64502..ec7a8c65f 100644 --- a/src/hierarchies/updaters/mnig_updater.h +++ b/src/hierarchies/updaters/mnig_updater.h @@ -5,15 +5,22 @@ #include "src/hierarchies/likelihoods/uni_lin_reg_likelihood.h" #include "src/hierarchies/priors/mnig_prior_model.h" -//! Updater specific for the `UniLinRegLikelihood` used in combination -//! with `MNIGPriorModel`, that is the model -//! y_i | reg_coeffs, var ~ N(reg_coeffs^T x_i, var) -//! reg_coeffs | var ~ N_p(mu0, sigsq * V^{-1}) -//! var ~ InvGamma(a, b) -//! -//! It exploits the conjugacy of the model to sample the full conditional of -//! (reg_coeffs, var) by calling `MNIGPriorModel::sample` with updated -//! parameters +/** + * Updater specific for the `UniLinRegLikelihood` used in combination + * with `MNIGPriorModel`, that is the model + * + * \f[ + * y_i \mid \bm{\beta}, \sigma^2 &\stackrel{\small\mathrm{iid}}{\sim} + * N(\bm{\beta}^T\bm{x}_i, \sigma^2) \\ + * \bm{\beta} \mid \sigma^2 &\sim N_p(\mu_{0}, \sigma^2 \mathbf{V}^{-1}) \\ + * \sigma^2 &\sim InvGamma(a, b) + * \f] + * + * It exploits the conjugacy of the model to sample the full conditional of + * \f$ (\bm{\beta}, \sigma^2) \f$ by calling `MNIGPriorModel::sample` with + * updated parameters + */ + class MNIGUpdater : public SemiConjugateUpdater { public: diff --git a/src/hierarchies/updaters/nnig_updater.h b/src/hierarchies/updaters/nnig_updater.h index 058a03fe5..5866735ba 100644 --- a/src/hierarchies/updaters/nnig_updater.h +++ b/src/hierarchies/updaters/nnig_updater.h @@ -5,14 +5,21 @@ #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" #include "src/hierarchies/priors/nig_prior_model.h" -//! Updater specific for the `UniNormLikelihood` used in combination -//! with `NIGPriorModel`, that is the model -//! y_i | mu, sigsq ~ N(mu, sigsq) -//! mu | sigsq ~ N(mu0, sigsq / lambda) -//! sigsq ~ InvGamma(a, b) -//! -//! It exploits the conjugacy of the model to sample the full conditional of -//! (mu, sigsq) by calling `NIGPriorModel::sample` with updated parameters +/** + * Updater specific for the `UniNormLikelihood` used in combination + * with `NIGPriorModel`, that is the model + * + * \f[ + * y_i \mid \mu, \sigma^2 &\stackrel{\small\mathrm{iid}}{\sim} N(\mu, + * \sigma^2) \\ + * \mu \mid \sigma^2 &\sim N(\mu_0, \sigma^2 / \lambda) \\ + * \sigma^2 &\sim InvGamma(a, b) + * \f] + * + * It exploits the conjugacy of the model to sample the full conditional of + * \f$ (\mu, \sigma^2) \f$ by calling `NIGPriorModel::sample` with updated + * parameters + */ class NNIGUpdater : public SemiConjugateUpdater { diff --git a/src/hierarchies/updaters/nnw_updater.h b/src/hierarchies/updaters/nnw_updater.h index d7e07c48c..b7877274d 100644 --- a/src/hierarchies/updaters/nnw_updater.h +++ b/src/hierarchies/updaters/nnw_updater.h @@ -5,14 +5,21 @@ #include "src/hierarchies/likelihoods/multi_norm_likelihood.h" #include "src/hierarchies/priors/nw_prior_model.h" -//! Updater specific for the `MultiNormLikelihood` used in combination -//! with `NWPriorModel`, that is the model -//! y_i | mu, Sigma ~ Nd(mu, Sigma) -//! mu | Sigma ~ N_d(mu0, sigsq / lambda) -//! Sigma^{-1} ~ Wishart(nu, Psi) -//! -//! It exploits the conjugacy of the model to sample the full conditional of -//! (mu, Sigma) by calling `NWPriorModel::sample` with updated parameters +/** + * Updater specific for the `MultiNormLikelihood` used in combination + * with `NWPriorModel`, that is the model + * + * \f[ + * y_i \mid \bm{\mu}, \Sigma &\stackrel{\small\mathrm{iid}}{\sim} + * N_d(\bm{mu}, \Sigma) \\ + * \bm{\mu} \mid \Sigma &\sim N_d(\bm{\mu}_0, \Sigma / \lambda) \\ + * \Sigma^{-1} &\sim Wishart(\nu, \Psi) + * \f] + * + * It exploits the conjugacy of the model to sample the full conditional of + * \f$ (\bm{\mu}, \Sigma) \f$ by calling `NWPriorModel::sample` with updated + * parameters. + */ class NNWUpdater : public SemiConjugateUpdater { diff --git a/src/hierarchies/updaters/nnxig_updater.h b/src/hierarchies/updaters/nnxig_updater.h index 4b752e1f6..195b8c44f 100644 --- a/src/hierarchies/updaters/nnxig_updater.h +++ b/src/hierarchies/updaters/nnxig_updater.h @@ -5,15 +5,21 @@ #include "src/hierarchies/likelihoods/uni_norm_likelihood.h" #include "src/hierarchies/priors/nxig_prior_model.h" -//! Updater specific for the `UniNormLikelihood` used in combination -//! with `NxIGPriorModel`, that is the model -//! \f[ -//! y_i | \mu, \sigma^2 &\sim N(\mu, \sigma^2) \\ -//! \mu &\sim N(\mu_0, \eta^2) \\ -//! \sigma^2 & \sim IG(a,b) -//! \f] -//! It exploits the semi-conjugacy of the model to sample the full conditional -//! of (mu, sigsq) by calling `NxIGPriorModel::sample` with updated parameters +/** + * Updater specific for the `UniNormLikelihood` used in combination + * with `NxIGPriorModel`, that is the model + * + * \f[ + * y_i \mid \mu, \sigma^2 &\stackrel{\small\mathrm{iid}}{\sim} N(\mu, + * \sigma^2) \\ + * \mu &\sim N(\mu_0, \eta^2) \\ + * \sigma^2 & \sim InvGamma(a,b) + * \f] + * + * It exploits the semi-conjugacy of the model to sample the full conditional + * of \f$ (\mu, \sigma^2) \f$ by calling `NxIGPriorModel::sample` with updated + * parameters + */ class NNxIGUpdater : public SemiConjugateUpdater { diff --git a/src/hierarchies/updaters/random_walk_updater.h b/src/hierarchies/updaters/random_walk_updater.h index dae7653ea..cb3076206 100644 --- a/src/hierarchies/updaters/random_walk_updater.h +++ b/src/hierarchies/updaters/random_walk_updater.h @@ -3,17 +3,24 @@ #include "metropolis_updater.h" -//! Metropolis-Hastings updater using an isotropic proposal function -//! centered in the current value of the parameters (unconstrained). -//! This class requires that the Hierarchy's state implements -//! the `get_unconstrained()`, `set_from_unconstrained()` and -//! `log_det_jac()` functions. -//! -//! Given the current value of the unconstrained parameters x, a new -//! value is proposed from -//! x_new ~ N(x_new, step_size * I) -//! and then either accepted (in which case the hierarchy's state is -//! set to x_new) or rejected. +/** + * Metropolis-Hastings updater using an isotropic proposal function + * centered in the current value of the parameters (unconstrained). + * This class requires that the Hierarchy's state implements + * the `get_unconstrained()`, `set_from_unconstrained()` and + * `log_det_jac()` functions. + * + * Given the current value of the unconstrained parameters \f$ x \f$, a new + * value is proposed from + * + * \f[ + * x_{new} \sim N(x, step\_size \cdot I) + * \f] + * + * and then either accepted (in which case the hierarchy's state is + * set to \f$ x_{new} \f$) or rejected. + */ + class RandomWalkUpdater : public MetropolisUpdater { protected: double step_size; From 056c63915848fb8230bb2e6ba34446ae51a775e8 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 17 May 2022 22:22:34 +0200 Subject: [PATCH 315/317] latex hotfix --- docs/hierarchies.rst | 4 ++-- docs/updaters.rst | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/hierarchies.rst b/docs/hierarchies.rst index f5272c1fd..c6dedcab9 100644 --- a/docs/hierarchies.rst +++ b/docs/hierarchies.rst @@ -6,9 +6,9 @@ Hierarchies In our algorithms, we store a vector of hierarchies, each of which represent a parameter :math:`\theta_h`. The hierarchy implements all the methods needed to update :math:`\theta_h`: sampling from the prior distribution :math:`P_0`, the full-conditional distribution (given the data {:math:`y_i` such that :math:`c_i = h`} ) and so on. -In BayesMix, each choice of :math:`G_0` is implemented in a different ``PriorModel`` object and each choice of :math:k(\cdot \mid \cdot)` in a ``Likelihood`` object, so that it is straightforward to create a new ``Hierarchy`` using one of the already implemented priors or likelihoods. +In BayesMix, each choice of :math:`G_0` is implemented in a different ``PriorModel`` object and each choice of :math:`k(\cdot \mid \cdot)` in a ``Likelihood`` object, so that it is straightforward to create a new ``Hierarchy`` using one of the already implemented priors or likelihoods. The sampling from the full conditional of :math:`\theta_h` is performed in an ``Updater`` class. -`State` classes are used to store parameters ``\theta_h`s of every mixture component. +`State` classes are used to store parameters :math:`\theta_h` s of every mixture component. Their main purpose is to handle serialization and de-serialization of the state .. toctree:: diff --git a/docs/updaters.rst b/docs/updaters.rst index 9e00718ce..ac3fb19c0 100644 --- a/docs/updaters.rst +++ b/docs/updaters.rst @@ -54,7 +54,6 @@ Updater Classes .. doxygenclass:: MalaUpdater :project: bayesmix :members: - .. doxygenclass:: NNIGUpdater :project: bayesmix :members: From 5463cd2a902b56556c39d27b07b42d92d19ba619 Mon Sep 17 00:00:00 2001 From: Matteo Gianella Date: Tue, 17 May 2022 22:23:26 +0200 Subject: [PATCH 316/317] latex hotfix --- src/algorithms/conditional_algorithm.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/conditional_algorithm.h b/src/algorithms/conditional_algorithm.h index a87fb6630..f7bed0391 100644 --- a/src/algorithms/conditional_algorithm.h +++ b/src/algorithms/conditional_algorithm.h @@ -23,7 +23,7 @@ * * where \f$ f(x \mid \theta_j) \f$ is a density for each value of \f$ \theta_j * \f$, \f$ c_i \f$ take values in \f$ \{1, \dots, k\} \f$ and \f$ w_1, \dots, - * w_k \f$ are nonnegative weights whose sum is a.s. 1, i.e. \f$ p(w_1, ... + * w_k \f$ are nonnegative weights whose sum is a.s. 1, i.e. \f$ p(w_1, \dots, * w_k) \f$ is a probability distribution on the k-1 dimensional unit simplex). * In this library, each \f$ \theta_j \f$ is represented as an `Hierarchy` * object (which inherits from `AbstractHierarchy`), that also holds the From 4591a44a7085d18b3486f5b698d795ba5e188e51 Mon Sep 17 00:00:00 2001 From: brunoguindani Date: Wed, 18 May 2022 10:18:17 +0200 Subject: [PATCH 317/317] Comments cleanup and README updates --- CMakeLists.txt | 7 -- README.md | 9 +- docs/conf.py | 2 - src/hierarchies/fa_hierarchy.h | 3 +- .../likelihoods/multi_norm_likelihood.cc | 4 +- .../likelihoods/states/CMakeLists.txt | 2 +- src/hierarchies/nnw_hierarchy.h | 1 + src/hierarchies/priors/fa_prior_model.cc | 2 - .../updaters/semi_conjugate_updater.h | 2 +- src/utils/covariates_getter.h | 2 - test/hierarchies.cc | 38 --------- test/likelihoods.cc | 2 +- test/lpdf.cc | 84 ------------------- test/write_proto.cc | 24 ------ 14 files changed, 10 insertions(+), 172 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a4d7cea61..e79f07f43 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -171,7 +171,6 @@ if (NOT DISABLE_DOCS) add_subdirectory(docs) endif() -# add_subdirectory(examples) if (NOT DISABLE_PLOTS) include(FetchContent) @@ -210,9 +209,3 @@ endif() if (NOT DISABLE_EXAMPLES) add_subdirectory(examples) endif() - -# Test MH updater -# add_executable(test_mh $ test_mh_updater.cpp) -# target_include_directories(test_mh PUBLIC ${INCLUDE_PATHS}) -# target_link_libraries(test_mh PUBLIC ${LINK_LIBRARIES}) -# target_compile_options(test_mh PUBLIC ${COMPILE_OPTIONS}) diff --git a/README.md b/README.md index 019625c81..7cb6972a4 100644 --- a/README.md +++ b/README.md @@ -16,13 +16,7 @@ Current state of the software: --> -where P is either the Dirichlet process or the Pitman--Yor process - -- We currently support univariate and multivariate location-scale mixture of Gaussian densities - -- Inference is carried out using algorithms such as Algorithm 2 in [Neal (2000)](http://www.stat.columbia.edu/npbayes/papers/neal_sampling.pdf) - -- Serialization of the MCMC chains is possible using Google's [Protocol Buffers](https://developers.google.com/protocol-buffers) aka `protobuf` +For descriptions of the models supported in our library, discussion of software design, and examples, please refer to the following paper: https://arxiv.org/abs/2205.08144 # Installation @@ -101,3 +95,4 @@ Documentation is available at https://bayesmix.readthedocs.io. # Contributions are welcome! Please check out [CONTRIBUTING.md](CONTRIBUTING.md) for details on how to collaborate with us. +You can also head to our [issues page](https://github.com/bayesmix-dev/bayesmix/issues) to check for useful enhancements needed. diff --git a/docs/conf.py b/docs/conf.py index af16e0bdc..9acbf232a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -57,8 +57,6 @@ def configureDoxyfile(input_dir, output_dir): html_theme = 'haiku' -# html_static_path = ['_static'] - highlight_language = 'cpp' imgmath_latex = 'latex' diff --git a/src/hierarchies/fa_hierarchy.h b/src/hierarchies/fa_hierarchy.h index 17da43564..fc7ce3079 100644 --- a/src/hierarchies/fa_hierarchy.h +++ b/src/hierarchies/fa_hierarchy.h @@ -62,10 +62,11 @@ class FAHierarchy } }; -//! Empirical-Bayes hyperparameters initialization for the FA HIerarchy. +//! Empirical-Bayes hyperparameters initialization for the FAHierarchy. //! Sets the hyperparameters in `hier` starting from the data on which the user //! wants to fit the model. inline void set_fa_hyperparams_from_data(FAHierarchy* hier) { + // TODO test this function auto dataset_ptr = std::static_pointer_cast(hier->get_likelihood()) ->get_dataset(); diff --git a/src/hierarchies/likelihoods/multi_norm_likelihood.cc b/src/hierarchies/likelihoods/multi_norm_likelihood.cc index ae51f3d47..f0cfae90d 100644 --- a/src/hierarchies/likelihoods/multi_norm_likelihood.cc +++ b/src/hierarchies/likelihoods/multi_norm_likelihood.cc @@ -12,8 +12,8 @@ double MultiNormLikelihood::compute_lpdf( void MultiNormLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, bool add) { - // Check if dim is not defined yet (usually not happens if hierarchy is - // initialized) + // Check if dim is not defined yet (this usually doesn't happen if the + // hierarchy is initialized) if (!dim) set_dim(datum.size()); // Updates if (add) { diff --git a/src/hierarchies/likelihoods/states/CMakeLists.txt b/src/hierarchies/likelihoods/states/CMakeLists.txt index 4f850a603..933c337b4 100644 --- a/src/hierarchies/likelihoods/states/CMakeLists.txt +++ b/src/hierarchies/likelihoods/states/CMakeLists.txt @@ -4,5 +4,5 @@ target_sources(bayesmix PUBLIC uni_ls_state.h multi_ls_state.h uni_lin_reg_ls_state.h - # fa_state.h + fa_state.h ) diff --git a/src/hierarchies/nnw_hierarchy.h b/src/hierarchies/nnw_hierarchy.h index 7d8bd330d..2cf36464a 100644 --- a/src/hierarchies/nnw_hierarchy.h +++ b/src/hierarchies/nnw_hierarchy.h @@ -65,6 +65,7 @@ class NNWHierarchy //! @return The evaluation of the lpdf double marg_lpdf(ProtoHypersPtr hier_params, const Eigen::RowVectorXd &datum) const override { + // TODO check Bayes rule for this hierarchy HyperParams pred_params = get_predictive_t_parameters(hier_params); Eigen::VectorXd diag = pred_params.scale_chol.diagonal(); double logdet = 2 * log(diag.array()).sum(); diff --git a/src/hierarchies/priors/fa_prior_model.cc b/src/hierarchies/priors/fa_prior_model.cc index d75442795..fa402b1d1 100644 --- a/src/hierarchies/priors/fa_prior_model.cc +++ b/src/hierarchies/priors/fa_prior_model.cc @@ -7,8 +7,6 @@ double FAPriorModel::lpdf(const google::protobuf::Message &state_) { // Proto2Eigen conversion Eigen::VectorXd mu = bayesmix::to_eigen(state.mu()); Eigen::VectorXd psi = bayesmix::to_eigen(state.psi()); - - // Eigen::MatrixXd eta = bayesmix::to_eigen(state.eta()); Eigen::MatrixXd lambda = bayesmix::to_eigen(state.lambda()); // Initialize lpdf value diff --git a/src/hierarchies/updaters/semi_conjugate_updater.h b/src/hierarchies/updaters/semi_conjugate_updater.h index 41e3bc23a..5609bf1b8 100644 --- a/src/hierarchies/updaters/semi_conjugate_updater.h +++ b/src/hierarchies/updaters/semi_conjugate_updater.h @@ -66,7 +66,7 @@ void SemiConjugateUpdater::draw( auto& likecast = downcast_likelihood(like); auto& priorcast = downcast_prior(prior); // Sample from the full conditional of a semi-conjugate hierarchy - bool set_card = true; /*, use_post_hypers=true;*/ + bool set_card = true; if (likecast.get_card() == 0) { likecast.set_state(priorcast.sample(), !set_card); } else { diff --git a/src/utils/covariates_getter.h b/src/utils/covariates_getter.h index 46adb9802..530590182 100644 --- a/src/utils/covariates_getter.h +++ b/src/utils/covariates_getter.h @@ -2,8 +2,6 @@ #define BAYESMIX_SRC_UTILS_COVARIATES_GETTER_H #include -// #include "src/hierarchies/likelihoods/abstract_likelihood.h" -// #include "src/hierarchies/priors/abstract_prior_model.h" class covariates_getter { protected: diff --git a/test/hierarchies.cc b/test/hierarchies.cc index db3090f0c..0a4a94c25 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -335,44 +335,6 @@ TEST(fa_hierarchy, draw) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } -// TEST(fahierarchy, draw_auto) { -// auto hier = std::make_shared(); -// bayesmix::FAPrior prior; -// Eigen::VectorXd mutilde(0); -// bayesmix::Vector mutilde_proto; -// bayesmix::to_proto(mutilde, &mutilde_proto); -// int q = 2; -// double phi = 1.0; -// double alpha0 = 5.0; -// Eigen::VectorXd beta(0); -// bayesmix::Vector beta_proto; -// bayesmix::to_proto(beta, &beta_proto); -// Eigen::MatrixXd dataset(5, 5); -// dataset << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, -// 19, -// 20, 1, 5, 7, 8, 9; -// hier->set_dataset(&dataset); -// *prior.mutable_fixed_values()->mutable_mutilde() = mutilde_proto; -// prior.mutable_fixed_values()->set_phi(phi); -// prior.mutable_fixed_values()->set_alpha0(alpha0); -// prior.mutable_fixed_values()->set_q(q); -// *prior.mutable_fixed_values()->mutable_beta() = beta_proto; -// hier->get_mutable_prior()->CopyFrom(prior); -// hier->initialize(); - -// auto hier2 = hier->clone(); -// hier2->sample_prior(); - -// bayesmix::AlgorithmState out; -// bayesmix::AlgorithmState::ClusterState* clusval = -// out.add_cluster_states(); bayesmix::AlgorithmState::ClusterState* clusval2 -// = out.add_cluster_states(); hier->write_state_to_proto(clusval); -// hier2->write_state_to_proto(clusval2); - -// ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()) -// << clusval->DebugString() << clusval2->DebugString(); -// } - TEST(fa_hierarchy, sample_given_data) { auto hier = std::make_shared(); bayesmix::FAPrior prior; diff --git a/test/likelihoods.cc b/test/likelihoods.cc index d0fccc775..4cb2a7d5e 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -380,7 +380,7 @@ TEST(laplace_likelihood, eval_lpdf_unconstrained) { lpdf += like->lpdf(data.row(i)); } - like->set_dataset(&data); // Questa cosa è sempre garantita?? + like->set_dataset(&data); double clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); ASSERT_DOUBLE_EQ(lpdf, clus_lpdf); diff --git a/test/lpdf.cc b/test/lpdf.cc index 0f41f69ba..fe1cde610 100644 --- a/test/lpdf.cc +++ b/test/lpdf.cc @@ -7,7 +7,6 @@ #include "algorithm_state.pb.h" #include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" -// #include "src/hierarchies/nnw_hierarchy.h" #include "src/utils/proto_utils.h" TEST(lpdf, nnig) { @@ -54,89 +53,6 @@ TEST(lpdf, nnig) { ASSERT_DOUBLE_EQ(sum, marg); } -// TEST(lpdf, nnw) { // TODO -// using namespace stan::math; -// NNWHierarchy hier; -// bayesmix::NNWPrior hier_prior; -// Eigen::Vector2d mu0; mu0 << 5.5, 5.5; -// bayesmix::Vector mu0_proto; -// bayesmix::to_proto(mu0, &mu0_proto); -// double lambda0 = 0.2; -// double nu0 = 5.0; -// Eigen::Matrix2d tau0 = Eigen::Matrix2d::Identity() / nu0; -// bayesmix::Matrix tau0_proto; -// bayesmix::to_proto(tau0, &tau0_proto); -// *hier_prior.mutable_fixed_values()->mutable_mean() = mu0_proto; -// hier_prior.mutable_fixed_values()->set_var_scaling(lambda0); -// hier_prior.mutable_fixed_values()->set_deg_free(nu0); -// *hier_prior.mutable_fixed_values()->mutable_scale() = tau0_proto; -// hier.set_prior(hier_prior); -// hier.initialize(); -// -// Eigen::VectorXd mu = mu0; -// Eigen::MatrixXd tau = lambda0 * Eigen::Matrix2d::Identity(); -// -// Eigen::RowVectorXd datum(2); -// datum << 4.5, 4.5; -// -// // Compute prior parameters -// Eigen::MatrixXd tau_pr = lambda0 * tau0; -// -// // Compute posterior parameters -// double mu_n = (lambda0 * mu0 + datum(0)) / (lambda0 + 1); -// double alpha_n = alpha0 + 0.5; -// double lambda_n = lambda0 + 1; -// double nu_n = nu0 + 0.5; -// Eigen::VectorXd mu_n = -// (lambda0 * mu0 + datum.transpose()) / (lambda0 + 1); -// Eigen::MatrixXd tau_temp = -// stan::math::inverse_spd(tau0) + (0.5 * lambda0 / (lambda0 + 1)) * -// (datum.transpose() - mu0) * -// (datum - mu0.transpose()); -// Eigen::MatrixXd tau_n = stan::math::inverse_spd(tau_temp); -// Eigen::MatrixXd tau_post = lambda_n * tau_n; -// -// // Compute pieces -// double prior1 = stan::math::wishart_lpdf(tau, nu0, tau0); -// double prior2 = stan::math::multi_normal_prec_lpdf(mu, mu0, tau_pr); -// double prior = prior1 + prior2; -// double like = hier.get_like_lpdf(datum); -// double post1 = stan::math::wishart_lpdf(tau, nu_n, tau_post); -// double post2 = stan::math::multi_normal_prec_lpdf(mu, mu0, tau_post); -// double post = post1 + post2; -// -// // Bayes: logmarg(x) = logprior(phi) + loglik(x|phi) - logpost(phi|x) -// double sum = prior + like - post; -// double marg = hier.prior_pred_lpdf(false, datum); -// -// // Compute logdet's -// Eigen::MatrixXd tauchol0 = -// Eigen::LLT(tau0).matrixL().transpose(); -// double logdet0 = 2 * log(tauchol0.diagonal().array()).sum(); -// Eigen::MatrixXd tauchol_n = -// Eigen::LLT(tau_n).matrixL().transpose(); -// double logdet_n = 2 * log(tauchol_n.diagonal().array()).sum(); -// -// // lmgamma(dim, x) -// int dim = 2; -// double marg_murphy = lmgamma(dim, 0.5 * nu_n) + 0.5 * nu_n * logdet_n + -// 0.5 * dim * log(lambda0) + dim * NEG_LOG_SQRT_TWO_PI -// - lmgamma(dim, 0.5 * nu0) - 0.5 * nu0 * logdet0 - 0.5 -// * dim * log(lambda_n); -// -// // std::cout << "prior1=" << prior1 << std::endl; -// // std::cout << "prior2=" << prior2 << std::endl; -// // std::cout << "prior =" << prior << std::endl; -// // std::cout << "like =" << like << std::endl; -// // std::cout << "post1 =" << post1 << std::endl; -// // std::cout << "post2 =" << post2 << std::endl; -// // std::cout << "post =" << post << std::endl; -// std::cout << "sum =" << sum << std::endl; -// std::cout << "marg =" << marg << std::endl; -// std::cout << "murphy=" << marg_murphy << std::endl; -// ASSERT_DOUBLE_EQ(marg, marg_murphy); -// } - TEST(lpdf, lin_reg_uni) { // Create hierarchy objects LinRegUniHierarchy hier; diff --git a/test/write_proto.cc b/test/write_proto.cc index 3be320ec9..da866e40a 100644 --- a/test/write_proto.cc +++ b/test/write_proto.cc @@ -3,7 +3,6 @@ #include "algorithm_state.pb.h" #include "ls_state.pb.h" #include "src/hierarchies/nnig_hierarchy.h" -// #include "src/hierarchies/nnw_hierarchy.h" #include "src/utils/proto_utils.h" TEST(set_state, uni_ls) { @@ -46,26 +45,3 @@ TEST(write_proto, uni_ls) { ASSERT_EQ(mean, out_mean); ASSERT_EQ(var, out_var); } - -// TEST(set_state, multi_ls) { -// Eigen::VectorXd mean = Eigen::VectorXd::Ones(5); -// Eigen::MatrixXd prec = Eigen::MatrixXd::Identity(5, 5); -// prec(1, 1) = 10.0; - -// bayesmix::MultiLSState curr; -// bayesmix::to_proto(mean, curr.mutable_mean()); -// bayesmix::to_proto(prec, curr.mutable_prec()); - -// ASSERT_EQ(curr.mean().data(0), 1.0); -// ASSERT_EQ(curr.prec().data(0), 1.0); -// ASSERT_EQ(curr.prec().data(6), 10.0); - -// bayesmix::AlgorithmState::ClusterState clusval_in; -// clusval_in.mutable_multi_ls_state()->CopyFrom(curr); -// NNWHierarchy cluster; -// cluster.set_state_from_proto(clusval_in); - -// ASSERT_EQ(curr.mean().data(0), cluster.get_state().mean(0)); -// ASSERT_EQ(curr.prec().data(0), cluster.get_state().prec(0, 0)); -// ASSERT_EQ(curr.prec().data(6), cluster.get_state().prec(1, 1)); -// }