Skip to content

Commit

Permalink
actually getting posterior hypers
Browse files Browse the repository at this point in the history
  • Loading branch information
mberaha committed Jul 8, 2022
1 parent 07cc624 commit 8ddf06a
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_BUILD_TYPE Release)

set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH})
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -funroll-loops -ftree-vectorize")
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -funroll-loops -ftree-vectorize -Wno-deprecated")
set(CMAKE_FIND_PACKAGE_PREFER_CONFIG TRUE)

# Require PkgConfig
Expand Down
4 changes: 2 additions & 2 deletions src/hierarchies/base_hierarchy.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ class BaseHierarchy : public AbstractHierarchy {
double conditional_pred_lpdf(const Eigen::RowVectorXd &datum,
const Eigen::RowVectorXd &covariate =
Eigen::RowVectorXd(0)) const override {
return get_marg_lpdf(updater->compute_posterior_hypers(*like, *prior),
datum, covariate);
return get_marg_lpdf(updater->get_posterior_hypers(*like, *prior), datum,
covariate);
}

//! Evaluates the log-prior predictive distr. of data in a grid of points
Expand Down
30 changes: 24 additions & 6 deletions src/hierarchies/updaters/abstract_updater.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,34 @@ class AbstractUpdater {
}

//! Stores the posterior hyperparameters in an appropriate container
virtual void save_posterior_hypers(ProtoHypersPtr post_hypers_) {
void save_posterior_hypers(ProtoHypersPtr post_hypers_) {
if (!is_conjugate()) {
throw(
std::runtime_error("Cannot call save_posterior_hypers() from a "
"non-(semi)conjugate updater"));
throw std::runtime_error(
"Cannot call save_posterior_hypers() from a "
"non-(semi)conjugate updater");
} else {
throw(std::runtime_error(
"save_posterior_hypers() not implemented for this updater"));
posterior_hypers = post_hypers_;
}
}

virtual ProtoHypersPtr get_posterior_hypers(AbstractLikelihood &like,
AbstractPriorModel &prior) {
if (!is_conjugate()) {
throw std::runtime_error(
"Cannot call get_posterior_hypers() from a "
"non-(semi)conjugate updater");
} else {
if (posterior_hypers == nullptr) {
posterior_hypers = compute_posterior_hypers(like, prior);
}

return posterior_hypers;
}
}

protected:
bool saved_posterior_hypers = false;
ProtoHypersPtr posterior_hypers = nullptr;
};

#endif // BAYESMIX_HIERARCHIES_UPDATERS_ABSTRACT_UPDATER_H_
10 changes: 0 additions & 10 deletions src/hierarchies/updaters/semi_conjugate_updater.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,9 @@ class SemiConjugateUpdater : public AbstractUpdater {

//! Used by algorithms such as Neal3 and SplitMerge
//! It stores the hyperparameters computed by `compute_posterior_hypers`
void save_posterior_hypers(ProtoHypersPtr post_hypers_) override;

protected:
Likelihood& downcast_likelihood(AbstractLikelihood& like_);
PriorModel& downcast_prior(AbstractPriorModel& prior_);
ProtoHypersPtr post_hypers = std::make_shared<ProtoHypers>();
};

// Methods' definitions
Expand Down Expand Up @@ -76,11 +73,4 @@ void SemiConjugateUpdater<Likelihood, PriorModel>::draw(
}
}

template <class Likelihood, class PriorModel>
void SemiConjugateUpdater<Likelihood, PriorModel>::save_posterior_hypers(
ProtoHypersPtr post_hypers_) {
post_hypers = post_hypers_;
return;
}

#endif // BAYESMIX_HIERARCHIES_UPDATERS_SEMI_CONJUGATE_UPDATER_H_

0 comments on commit 8ddf06a

Please sign in to comment.