Skip to content

Commit

Permalink
Refactor to shared_ptr: part3.1 change VINSFrame* to shared_ptr: fini…
Browse files Browse the repository at this point in the history
…sh D2VINS.
  • Loading branch information
xuhao3e8 committed Dec 12, 2024
1 parent 4c0e48f commit 449924d
Show file tree
Hide file tree
Showing 14 changed files with 127 additions and 122 deletions.
3 changes: 2 additions & 1 deletion d2common/include/d2common/d2imu.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace D2Common {
typedef std::lock_guard<std::recursive_mutex> Guard;

struct VINSFrame;
using VINSFramePtr = std::shared_ptr<VINSFrame>;

struct IMUData {
static Vector3d Gravity;
Expand Down Expand Up @@ -96,7 +97,7 @@ class IMUBuffer {
std::pair<IMUBuffer, int> periodIMU(int i0, double t1) const;

Swarm::Odometry propagation(const Swarm::Odometry & odom, const Vector3d & Ba, const Vector3d & Bg) const;
Swarm::Odometry propagation(const VINSFrame & baseframe) const;
Swarm::Odometry propagation(const VINSFramePtr & baseframe) const;
IMUData operator[](int i) const {
return buf.at(i);
}
Expand Down
4 changes: 2 additions & 2 deletions d2common/include/d2common/d2vinsframe.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ struct VINSFrame: public D2BaseFrame {
VINSFrame():Ba(0., 0., 0.), Bg(0., 0., 0.)
{}

VINSFrame(const VisualImageDescArray & frame, const IMUBuffer & buf, const VINSFrame & prev_frame);
VINSFrame(const VisualImageDescArray & frame, const std::pair<IMUBuffer, int> & buf, const VINSFrame & prev_frame);
VINSFrame(const VisualImageDescArray & frame, const IMUBuffer & buf, const std::shared_ptr<VINSFrame>& prev_frame);
VINSFrame(const VisualImageDescArray & frame, const std::pair<IMUBuffer, int> & buf, const std::shared_ptr<VINSFrame>& prev_frame);

VINSFrame(const VisualImageDescArray & frame, const Vector3d & _Ba, const Vector3d & _Bg);
VINSFrame(const VisualImageDescArray & frame);
Expand Down
1 change: 1 addition & 0 deletions d2common/include/d2common/solver/ConsensusSolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class ConsensusSolver : public SolverWrapper {

virtual void addResidual(const std::shared_ptr<ResidualInfo>& residual_info) override;
SolverReport solve() override;
virtual SolverReport solve(std::function<void()> func_set_properties) override { assert(false && "Unused");}
void setToken(int token) {
solver_token = token;
}
Expand Down
9 changes: 6 additions & 3 deletions d2common/include/d2common/solver/SolverWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <d2common/d2state.hpp>
#include <d2common/solver/BaseParamResInfo.hpp>
#include "spdlog/spdlog.h"
#include <functional>

namespace D2Common {
class ResidualInfo;
Expand Down Expand Up @@ -45,7 +46,8 @@ class SolverWrapper {
public:
SolverWrapper(D2State * _state);
virtual void addResidual(const std::shared_ptr<ResidualInfo>& residual_info);
virtual SolverReport solve() = 0;
virtual SolverReport solve() = 0; // TODO: remove
virtual SolverReport solve(std::function<void()> func_set_properties) = 0;
ceres::Problem & getProblem();
virtual void reset();
};
Expand All @@ -56,8 +58,9 @@ class CeresSolver : public SolverWrapper {
public:
CeresSolver(D2State * _state, ceres::Solver::Options _options):
SolverWrapper(_state), options(_options) {}
virtual void addResidual(const std::shared_ptr<ResidualInfo>& residual_info) override;
SolverReport solve() override;
//TODO: set as override
virtual SolverReport solve() override { assert(false && "Unused");};
virtual SolverReport solve(std::function<void()> func_set_properties) override;
};

}
4 changes: 2 additions & 2 deletions d2common/src/d2imu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ std::pair<IMUBuffer, int> IMUBuffer::periodIMU(int i0, double t1) const {
return std::make_pair(slice(i0 + 1, i1 + 1), i1 + 1);
}

Swarm::Odometry IMUBuffer::propagation(const VINSFrame& baseframe) const {
return propagation(baseframe.odom, baseframe.Ba, baseframe.Bg);
Swarm::Odometry IMUBuffer::propagation(const VINSFramePtr& baseframe) const {
return propagation(baseframe->odom, baseframe->Ba, baseframe->Bg);
}

Swarm::Odometry IMUBuffer::propagation(const Swarm::Odometry& prev_odom,
Expand Down
16 changes: 8 additions & 8 deletions d2common/src/d2vinsframe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
namespace D2Common {
double t0 = 0;
VINSFrame::VINSFrame(const VisualImageDescArray& frame, const IMUBuffer& buf,
const VINSFrame& prev_frame)
const std::shared_ptr<VINSFrame>& prev_frame)
: D2BaseFrame(frame.stamp, frame.frame_id, frame.drone_id,
frame.reference_frame_id, frame.is_keyframe,
frame.pose_drone),
Ba(prev_frame.Ba),
Bg(prev_frame.Bg),
prev_frame_id(prev_frame.frame_id) {
Ba(prev_frame->Ba),
Bg(prev_frame->Bg),
prev_frame_id(prev_frame->frame_id) {
pre_integrations = std::make_shared<IntegrationBase>(buf, Ba, Bg);
if (t0 == 0) {
t0 = stamp;
Expand All @@ -19,13 +19,13 @@ VINSFrame::VINSFrame(const VisualImageDescArray& frame, const IMUBuffer& buf,

VINSFrame::VINSFrame(const VisualImageDescArray& frame,
const std::pair<IMUBuffer, int>& buf,
const VINSFrame& prev_frame)
const std::shared_ptr<VINSFrame>& prev_frame)
: D2BaseFrame(frame.stamp, frame.frame_id, frame.drone_id,
frame.reference_frame_id, frame.is_keyframe,
frame.pose_drone),
Ba(prev_frame.Ba),
Bg(prev_frame.Bg),
prev_frame_id(prev_frame.frame_id),
Ba(prev_frame->Ba),
Bg(prev_frame->Bg),
prev_frame_id(prev_frame->frame_id),
imu_buf_index(buf.second) {
pre_integrations = std::make_shared<IntegrationBase>(buf.first, Ba, Bg);
if (t0 == 0) {
Expand Down
19 changes: 11 additions & 8 deletions d2common/src/solver/SolverWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@ void SolverWrapper::reset() {
residuals.clear();
}

void CeresSolver::addResidual(const std::shared_ptr<ResidualInfo>& residual_info) {
auto pointers = residual_info->paramsPointerList(state);
// printf("Add residual info %d", residual_info->residual_type);
problem->AddResidualBlock(CheckGetPtr(residual_info->cost_function),
SolverReport CeresSolver::solve(std::function<void()> func_set_properties) {
for (auto residual_info: residuals)
{
// Put here to avoid unable to check pointer been free
auto pointers = residual_info->paramsPointerList(state);
problem->AddResidualBlock(CheckGetPtr(residual_info->cost_function),
residual_info->loss_function.get(), pointers); // loss_function maybe nullptr
SolverWrapper::addResidual(residual_info);
}

SolverReport CeresSolver::solve() {
}
if (func_set_properties)
{
func_set_properties();
}
ceres::Solver::Summary summary;
ceres::Solve(options, problem, &summary);
SolverReport report;
Expand Down
3 changes: 1 addition & 2 deletions d2pgo/src/d2pgo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,7 @@ bool D2PGO::solve_single() {
if (config.enable_gravity_prior) {
setupGravityPriorFactors(solver);
}
setStateProperties(solver->getProblem());
auto report = solver->solve();
auto report = solver->solve([&](){setStateProperties(solver->getProblem());});
if (config.perturb_mode) {
postPerturbSolve();
} else {
Expand Down
Loading

0 comments on commit 449924d

Please sign in to comment.