Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Refactor #75

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ namespace gtsam {

class HybridValues;

/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;

KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
Expand Down
74 changes: 28 additions & 46 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
Expand All @@ -32,6 +33,16 @@

namespace gtsam {
/* *******************************************************************************/
/**
* @brief Helper struct for constructing HybridGaussianConditional objects
*
* This struct contains the following fields:
* - nrFrontals: Optional size_t for number of frontal variables
* - pairs: FactorValuePairs for storing conditionals with their negLogConstant
* - conditionals: Conditionals for storing conditionals. TODO(frank): kill!
* - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
* constructor
*/
struct HybridGaussianConditional::Helper {
std::optional<size_t> nrFrontals;
FactorValuePairs pairs;
Expand Down Expand Up @@ -68,16 +79,12 @@ struct HybridGaussianConditional::Helper {
explicit Helper(const Conditionals &conditionals)
: conditionals(conditionals),
minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GC::shared_ptr &c) -> GaussianFactorValuePair {
double value = 0.0;
if (c) {
if (!nrFrontals.has_value()) {
nrFrontals = c->nrFrontals();
}
value = c->negLogConstant();
minNegLogConstant = std::min(minNegLogConstant, value);
}
return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
auto func = [this](const GC::shared_ptr& gc) -> GaussianFactorValuePair {
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
if (!nrFrontals) nrFrontals = gc->nrFrontals();
double value = gc->negLogConstant();
minNegLogConstant = std::min(minNegLogConstant, value);
return {gc, value};
};
pairs = FactorValuePairs(conditionals, func);
if (!nrFrontals.has_value()) {
Expand All @@ -90,8 +97,15 @@ struct HybridGaussianConditional::Helper {

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents, const Helper &helper)
: BaseFactor(discreteParents, helper.pairs),
const DiscreteKeys& discreteParents, const Helper& helper)
: BaseFactor(discreteParents,
FactorValuePairs(helper.pairs,
[&](const GaussianFactorValuePair&
pair) { // subtract minNegLogConstant
return GaussianFactorValuePair{
pair.first,
pair.second - helper.minNegLogConstant};
})),
BaseConditional(*helper.nrFrontals),
conditionals_(helper.conditionals),
negLogConstant_(helper.minNegLogConstant) {}
Expand Down Expand Up @@ -135,29 +149,6 @@ HybridGaussianConditional::conditionals() const {
return conditionals_;
}

/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
const {
auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
// First check if conditional has not been pruned
if (gc) {
const double Cgm_Kgcm = gc->negLogConstant() - this->negLogConstant_;
// If there is a difference in the covariances, we need to account for
// that since the error is dependent on the mode.
if (Cgm_Kgcm > 0.0) {
// We add a constant factor which will be used when computing
// the probability of the discrete variables.
Vector c(1);
c << std::sqrt(2.0 * Cgm_Kgcm);
auto constantFactor = std::make_shared<JacobianFactor>(c);
return GaussianFactorGraph{gc, constantFactor};
}
}
return GaussianFactorGraph{gc};
};
return {conditionals_, wrap};
}

/* *******************************************************************************/
size_t HybridGaussianConditional::nrComponents() const {
size_t total = 0;
Expand Down Expand Up @@ -202,9 +193,6 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
void HybridGaussianConditional::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << (s.empty() ? "" : s + "\n");
if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybrid ";
BaseConditional::print("", formatter);
std::cout << " Discrete Keys = ";
for (auto &dk : discreteKeys()) {
Expand Down Expand Up @@ -237,7 +225,7 @@ KeyVector HybridGaussianConditional::continuousParents() const {
// remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key),
continuousParentKeys.end());
continuousParentKeys.end());
}
return continuousParentKeys;
}
Expand Down Expand Up @@ -270,13 +258,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
-> GaussianFactorValuePair {
const auto likelihood_m = conditional->likelihood(given);
const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_;
if (Cgm_Kgcm == 0.0) {
return {likelihood_m, 0.0};
} else {
// Add a constant to the likelihood in case the noise models
// are not all equal.
return {likelihood_m, Cgm_Kgcm};
}
return {likelihood_m, Cgm_Kgcm};
});
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
likelihoods);
Expand Down
4 changes: 1 addition & 3 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Conditional.h>
Expand Down Expand Up @@ -231,9 +232,6 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Helper &helper);

/// Convert to a DecisionTree of Gaussian factor graphs.
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;

Expand Down
Loading
Loading