Skip to content

Commit

Permalink
ekf2-drag: do not generate Kalman gain to save flash
Browse files Browse the repository at this point in the history
  • Loading branch information
bresch committed Jan 31, 2024
1 parent 7904ae7 commit 9040b15
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 492 deletions.
19 changes: 11 additions & 8 deletions src/modules/ekf2/EKF/drag_fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
*/

#include "ekf.h"
#include <ekf_derivation/generated/compute_drag_x_innov_var_and_k.h>
#include <ekf_derivation/generated/compute_drag_y_innov_var_and_k.h>
#include <ekf_derivation/generated/compute_drag_x_innov_var_and_h.h>
#include <ekf_derivation/generated/compute_drag_y_innov_var_and_h.h>

#include <mathlib/mathlib.h>
#include <lib/atmosphere/atmosphere.h>
Expand Down Expand Up @@ -110,7 +110,7 @@ void Ekf::fuseDrag(const dragSample &drag_sample)

bool fused[] {false, false};

VectorState Kfusion;
VectorState H;

// perform sequential fusion of XY specific forces
for (uint8_t axis_index = 0; axis_index < 2; axis_index++) {
Expand All @@ -128,16 +128,16 @@ void Ekf::fuseDrag(const dragSample &drag_sample)
_aid_src_drag.innovation_variance[axis_index] = NAN; // reset

if (axis_index == 0) {
sym::ComputeDragXInnovVarAndK(state_vector_prev, P, rho, bcoef_inv(axis_index), mcoef_corrrected, R_ACC, FLT_EPSILON,
&_aid_src_drag.innovation_variance[axis_index], &Kfusion);
sym::ComputeDragXInnovVarAndH(state_vector_prev, P, rho, bcoef_inv(axis_index), mcoef_corrrected, R_ACC, FLT_EPSILON,
&_aid_src_drag.innovation_variance[axis_index], &H);

if (!using_bcoef_x && !using_mcoef) {
continue;
}

} else if (axis_index == 1) {
sym::ComputeDragYInnovVarAndK(state_vector_prev, P, rho, bcoef_inv(axis_index), mcoef_corrrected, R_ACC, FLT_EPSILON,
&_aid_src_drag.innovation_variance[axis_index], &Kfusion);
sym::ComputeDragYInnovVarAndH(state_vector_prev, P, rho, bcoef_inv(axis_index), mcoef_corrrected, R_ACC, FLT_EPSILON,
&_aid_src_drag.innovation_variance[axis_index], &H);

if (!using_bcoef_y && !using_mcoef) {
continue;
Expand All @@ -157,7 +157,10 @@ void Ekf::fuseDrag(const dragSample &drag_sample)
&& PX4_ISFINITE(_aid_src_drag.innovation_variance[axis_index]) && PX4_ISFINITE(_aid_src_drag.innovation[axis_index])
&& (_aid_src_drag.test_ratio[axis_index] < 1.f)
) {
if (measurementUpdate(Kfusion, _aid_src_drag.innovation_variance[axis_index], _aid_src_drag.innovation[axis_index])) {

VectorState K = P * H / _aid_src_drag.innovation_variance[axis_index];

if (measurementUpdate(K, _aid_src_drag.innovation_variance[axis_index], _aid_src_drag.innovation[axis_index])) {
fused[axis_index] = true;
}
}
Expand Down
14 changes: 6 additions & 8 deletions src/modules/ekf2/EKF/python/ekf_derivation/derivation.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def predict_drag(
return bluff_body_drag + momentum_drag


def compute_drag_x_innov_var_and_k(
def compute_drag_x_innov_var_and_h(
state: VState,
P: MTangent,
rho: sf.Scalar,
Expand All @@ -581,11 +581,10 @@ def compute_drag_x_innov_var_and_k(
meas_pred = predict_drag(state, rho, cd, cm, epsilon)
Hx = sf.V1(meas_pred[0]).jacobian(state)
innov_var = (Hx * P * Hx.T + R)[0,0]
K = P * Hx.T / sf.Max(innov_var, epsilon)

return (innov_var, K)
return (innov_var, Hx.T)

def compute_drag_y_innov_var_and_k(
def compute_drag_y_innov_var_and_h(
state: VState,
P: MTangent,
rho: sf.Scalar,
Expand All @@ -599,9 +598,8 @@ def compute_drag_y_innov_var_and_k(
meas_pred = predict_drag(state, rho, cd, cm, epsilon)
Hy = sf.V1(meas_pred[1]).jacobian(state)
innov_var = (Hy * P * Hy.T + R)[0,0]
K = P * Hy.T / sf.Max(innov_var, epsilon)

return (innov_var, K)
return (innov_var, Hy.T)

def predict_gravity_direction(state: State):
# get transform from earth to body frame
Expand Down Expand Up @@ -674,8 +672,8 @@ def compute_gravity_z_innov_var_and_h(
if not args.disable_wind:
generate_px4_function(compute_airspeed_h_and_k, output_names=["H", "K"])
generate_px4_function(compute_airspeed_innov_and_innov_var, output_names=["innov", "innov_var"])
generate_px4_function(compute_drag_x_innov_var_and_k, output_names=["innov_var", "K"])
generate_px4_function(compute_drag_y_innov_var_and_k, output_names=["innov_var", "K"])
generate_px4_function(compute_drag_x_innov_var_and_h, output_names=["innov_var", "Hx"])
generate_px4_function(compute_drag_y_innov_var_and_h, output_names=["innov_var", "Hy"])
generate_px4_function(compute_sideslip_h_and_k, output_names=["H", "K"])
generate_px4_function(compute_sideslip_innov_and_innov_var, output_names=["innov", "innov_var"])
generate_px4_function(compute_wind_init_and_cov_from_airspeed, output_names=["wind", "P_wind"])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// -----------------------------------------------------------------------------
// This file was autogenerated by symforce from template:
// function/FUNCTION.h.jinja
// Do NOT modify by hand.
// -----------------------------------------------------------------------------

#pragma once

#include <matrix/math.hpp>

namespace sym {

/**
* This function was autogenerated from a symbolic function. Do not modify by hand.
*
* Symbolic function: compute_drag_x_innov_var_and_h
*
* Args:
* state: Matrix24_1
* P: Matrix23_23
* rho: Scalar
* cd: Scalar
* cm: Scalar
* R: Scalar
* epsilon: Scalar
*
* Outputs:
* innov_var: Scalar
* Hx: Matrix23_1
*/
template <typename Scalar>
void ComputeDragXInnovVarAndH(const matrix::Matrix<Scalar, 24, 1>& state,
const matrix::Matrix<Scalar, 23, 23>& P, const Scalar rho,
const Scalar cd, const Scalar cm, const Scalar R,
const Scalar epsilon, Scalar* const innov_var = nullptr,
matrix::Matrix<Scalar, 23, 1>* const Hx = nullptr) {
// Total ops: 317

// Input arrays

// Intermediate terms (73)
const Scalar _tmp0 = 2 * state(3, 0);
const Scalar _tmp1 = _tmp0 * state(0, 0);
const Scalar _tmp2 = 2 * state(2, 0);
const Scalar _tmp3 = _tmp2 * state(1, 0);
const Scalar _tmp4 = _tmp1 + _tmp3;
const Scalar _tmp5 = _tmp4 * cm;
const Scalar _tmp6 = std::pow(state(3, 0), Scalar(2));
const Scalar _tmp7 = -2 * _tmp6;
const Scalar _tmp8 = std::pow(state(2, 0), Scalar(2));
const Scalar _tmp9 = -2 * _tmp8;
const Scalar _tmp10 = _tmp7 + _tmp9 + 1;
const Scalar _tmp11 = -state(22, 0) + state(4, 0);
const Scalar _tmp12 = -state(23, 0) + state(5, 0);
const Scalar _tmp13 = _tmp2 * state(0, 0);
const Scalar _tmp14 = -_tmp13;
const Scalar _tmp15 = _tmp0 * state(1, 0);
const Scalar _tmp16 = _tmp14 + _tmp15;
const Scalar _tmp17 = _tmp12 * _tmp4 + _tmp16 * state(6, 0);
const Scalar _tmp18 = _tmp10 * _tmp11 + _tmp17;
const Scalar _tmp19 = 2 * _tmp18;
const Scalar _tmp20 = _tmp19 * _tmp4;
const Scalar _tmp21 = _tmp0 * state(2, 0);
const Scalar _tmp22 = 2 * state(0, 0) * state(1, 0);
const Scalar _tmp23 = -_tmp22;
const Scalar _tmp24 = _tmp21 + _tmp23;
const Scalar _tmp25 = std::pow(state(1, 0), Scalar(2));
const Scalar _tmp26 = 1 - 2 * _tmp25;
const Scalar _tmp27 = _tmp26 + _tmp9;
const Scalar _tmp28 = _tmp13 + _tmp15;
const Scalar _tmp29 = _tmp11 * _tmp28 + _tmp12 * _tmp24;
const Scalar _tmp30 = _tmp27 * state(6, 0) + _tmp29;
const Scalar _tmp31 = 2 * _tmp30;
const Scalar _tmp32 = _tmp24 * _tmp31;
const Scalar _tmp33 = _tmp26 + _tmp7;
const Scalar _tmp34 = -_tmp1;
const Scalar _tmp35 = _tmp3 + _tmp34;
const Scalar _tmp36 = _tmp21 + _tmp22;
const Scalar _tmp37 = _tmp11 * _tmp35 + _tmp36 * state(6, 0);
const Scalar _tmp38 = _tmp12 * _tmp33 + _tmp37;
const Scalar _tmp39 = 2 * _tmp38;
const Scalar _tmp40 = _tmp33 * _tmp39;
const Scalar _tmp41 = std::sqrt(Scalar(std::pow(_tmp18, Scalar(2)) + std::pow(_tmp30, Scalar(2)) +
std::pow(_tmp38, Scalar(2)) + epsilon));
const Scalar _tmp42 = cd * rho;
const Scalar _tmp43 = Scalar(0.25) * _tmp18 * _tmp42 / _tmp41;
const Scalar _tmp44 = Scalar(0.5) * _tmp41 * _tmp42;
const Scalar _tmp45 = _tmp4 * _tmp44;
const Scalar _tmp46 = -_tmp43 * (_tmp20 + _tmp32 + _tmp40) - _tmp45 - _tmp5;
const Scalar _tmp47 = -_tmp25;
const Scalar _tmp48 = _tmp47 + _tmp6;
const Scalar _tmp49 = std::pow(state(0, 0), Scalar(2));
const Scalar _tmp50 = -_tmp49;
const Scalar _tmp51 = _tmp50 + _tmp8;
const Scalar _tmp52 = -_tmp3;
const Scalar _tmp53 = -_tmp15;
const Scalar _tmp54 = -_tmp6;
const Scalar _tmp55 = _tmp12 * (_tmp47 + _tmp49 + _tmp54 + _tmp8) + _tmp37;
const Scalar _tmp56 = -_tmp43 * (_tmp19 * _tmp55 + _tmp39 * (_tmp11 * (_tmp48 + _tmp51) +
_tmp12 * (_tmp34 + _tmp52) +
state(6, 0) * (_tmp13 + _tmp53))) -
_tmp44 * _tmp55 - _tmp55 * cm;
const Scalar _tmp57 = -_tmp43 * (-_tmp20 - _tmp32 - _tmp40) + _tmp45 + _tmp5;
const Scalar _tmp58 = _tmp10 * cm;
const Scalar _tmp59 = _tmp10 * _tmp19;
const Scalar _tmp60 = _tmp28 * _tmp31;
const Scalar _tmp61 = _tmp35 * _tmp39;
const Scalar _tmp62 = _tmp10 * _tmp44;
const Scalar _tmp63 = -_tmp43 * (-_tmp59 - _tmp60 - _tmp61) + _tmp58 + _tmp62;
const Scalar _tmp64 = -_tmp8;
const Scalar _tmp65 = -_tmp21;
const Scalar _tmp66 = _tmp49 + _tmp64;
const Scalar _tmp67 =
_tmp43 * (_tmp31 * (_tmp11 * (_tmp1 + _tmp52) + _tmp12 * (_tmp25 + _tmp50 + _tmp6 + _tmp64) +
state(6, 0) * (_tmp23 + _tmp65)) +
_tmp39 * (_tmp29 + state(6, 0) * (_tmp48 + _tmp66)));
const Scalar _tmp68 = _tmp25 + _tmp54;
const Scalar _tmp69 =
_tmp11 * (_tmp14 + _tmp53) + _tmp12 * (_tmp22 + _tmp65) + state(6, 0) * (_tmp51 + _tmp68);
const Scalar _tmp70 =
-_tmp43 * (_tmp19 * _tmp69 + _tmp31 * (_tmp11 * (_tmp66 + _tmp68) + _tmp17)) -
_tmp44 * _tmp69 - _tmp69 * cm;
const Scalar _tmp71 = -_tmp43 * (_tmp59 + _tmp60 + _tmp61) - _tmp58 - _tmp62;
const Scalar _tmp72 = -_tmp16 * _tmp44 - _tmp16 * cm -
_tmp43 * (_tmp16 * _tmp19 + _tmp27 * _tmp31 + _tmp36 * _tmp39);

// Output terms (2)
if (innov_var != nullptr) {
Scalar& _innov_var = (*innov_var);

_innov_var =
R +
_tmp46 * (-P(0, 4) * _tmp67 + P(1, 4) * _tmp70 + P(2, 4) * _tmp56 + P(21, 4) * _tmp63 +
P(22, 4) * _tmp57 + P(3, 4) * _tmp71 + P(4, 4) * _tmp46 + P(5, 4) * _tmp72) +
_tmp56 * (-P(0, 2) * _tmp67 + P(1, 2) * _tmp70 + P(2, 2) * _tmp56 + P(21, 2) * _tmp63 +
P(22, 2) * _tmp57 + P(3, 2) * _tmp71 + P(4, 2) * _tmp46 + P(5, 2) * _tmp72) +
_tmp57 * (-P(0, 22) * _tmp67 + P(1, 22) * _tmp70 + P(2, 22) * _tmp56 + P(21, 22) * _tmp63 +
P(22, 22) * _tmp57 + P(3, 22) * _tmp71 + P(4, 22) * _tmp46 + P(5, 22) * _tmp72) +
_tmp63 * (-P(0, 21) * _tmp67 + P(1, 21) * _tmp70 + P(2, 21) * _tmp56 + P(21, 21) * _tmp63 +
P(22, 21) * _tmp57 + P(3, 21) * _tmp71 + P(4, 21) * _tmp46 + P(5, 21) * _tmp72) -
_tmp67 * (-P(0, 0) * _tmp67 + P(1, 0) * _tmp70 + P(2, 0) * _tmp56 + P(21, 0) * _tmp63 +
P(22, 0) * _tmp57 + P(3, 0) * _tmp71 + P(4, 0) * _tmp46 + P(5, 0) * _tmp72) +
_tmp70 * (-P(0, 1) * _tmp67 + P(1, 1) * _tmp70 + P(2, 1) * _tmp56 + P(21, 1) * _tmp63 +
P(22, 1) * _tmp57 + P(3, 1) * _tmp71 + P(4, 1) * _tmp46 + P(5, 1) * _tmp72) +
_tmp71 * (-P(0, 3) * _tmp67 + P(1, 3) * _tmp70 + P(2, 3) * _tmp56 + P(21, 3) * _tmp63 +
P(22, 3) * _tmp57 + P(3, 3) * _tmp71 + P(4, 3) * _tmp46 + P(5, 3) * _tmp72) +
_tmp72 * (-P(0, 5) * _tmp67 + P(1, 5) * _tmp70 + P(2, 5) * _tmp56 + P(21, 5) * _tmp63 +
P(22, 5) * _tmp57 + P(3, 5) * _tmp71 + P(4, 5) * _tmp46 + P(5, 5) * _tmp72);
}

if (Hx != nullptr) {
matrix::Matrix<Scalar, 23, 1>& _hx = (*Hx);

_hx.setZero();

_hx(0, 0) = -_tmp67;
_hx(1, 0) = _tmp70;
_hx(2, 0) = _tmp56;
_hx(3, 0) = _tmp71;
_hx(4, 0) = _tmp46;
_hx(5, 0) = _tmp72;
_hx(21, 0) = _tmp63;
_hx(22, 0) = _tmp57;
}
} // NOLINT(readability/fn_size)

// NOLINTNEXTLINE(readability/fn_size)
} // namespace sym
Loading

0 comments on commit 9040b15

Please sign in to comment.