Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(interpolation): fix spline bug #1541

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,13 @@
#ifndef INTERPOLATION__SPLINE_INTERPOLATION_HPP_
#define INTERPOLATION__SPLINE_INTERPOLATION_HPP_

#include "interpolation/interpolation_utils.hpp"
#include "tier4_autoware_utils/geometry/geometry.hpp"
#include <Eigen/Dense>

#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <vector>

namespace interpolation
{
// NOTE: X(s) = a_i (s - s_i)^3 + b_i (s - s_i)^2 + c_i (s - s_i) + d_i : (i = 0, 1, ... N-1)
struct MultiSplineCoef
{
MultiSplineCoef() = default;

explicit MultiSplineCoef(const size_t num_spline)
{
a.resize(num_spline);
b.resize(num_spline);
c.resize(num_spline);
d.resize(num_spline);
}

std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};

// static spline interpolation functions
std::vector<double> slerp(
Expand Down Expand Up @@ -84,8 +62,14 @@ class SplineInterpolation
std::vector<double> getSplineInterpolatedDiffValues(const std::vector<double> & query_keys) const;

private:
Eigen::VectorXd a_;
Eigen::VectorXd b_;
Eigen::VectorXd c_;
Eigen::VectorXd d_;

std::vector<double> base_keys_;
interpolation::MultiSplineCoef multi_spline_coef_;

Eigen::Index get_index(double key) const;
};

#endif // INTERPOLATION__SPLINE_INTERPOLATION_HPP_
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
#define INTERPOLATION__SPLINE_INTERPOLATION_POINTS_2D_HPP_

#include "interpolation/spline_interpolation.hpp"
#include "tier4_autoware_utils/geometry/geometry.hpp"

#include <geometry_msgs/msg/point.hpp>

#include <vector>

Expand Down
1 change: 1 addition & 0 deletions common/interpolation/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<license>Apache License 2.0</license>
<buildtool_depend>ament_cmake_auto</buildtool_depend>

<depend>eigen</depend>
<depend>tier4_autoware_utils</depend>

<test_depend>ament_lint_auto</test_depend>
Expand Down
210 changes: 89 additions & 121 deletions common/interpolation/src/spline_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,70 +14,43 @@

#include "interpolation/spline_interpolation.hpp"

#include <vector>
#include "interpolation/interpolation_utils.hpp"

namespace
{
// solve Ax = d
// where A is tridiagonal matrix
// [b_0 c_0 ... ]
// [a_0 b_1 c_1 ... O ]
// A = [ ... ]
// [ O ... a_N-3 b_N-2 c_N-2]
// [ ... a_N-2 b_N-1]
struct TDMACoef
#include <algorithm>

Eigen::VectorXd solve_tridiagonal_matrix_algorithm(
const Eigen::Ref<const Eigen::VectorXd> & a, const Eigen::Ref<const Eigen::VectorXd> & b,
const Eigen::Ref<const Eigen::VectorXd> & c, const Eigen::Ref<const Eigen::VectorXd> & d)
{
explicit TDMACoef(const size_t num_row)
{
a.resize(num_row - 1);
b.resize(num_row);
c.resize(num_row - 1);
d.resize(num_row);
auto n = d.size();

if (n == 1) {
return d.array() / b.array();
}

std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};
Eigen::VectorXd c_prime = Eigen::VectorXd::Zero(n);
Eigen::VectorXd d_prime = Eigen::VectorXd::Zero(n);
Eigen::VectorXd x = Eigen::VectorXd::Zero(n);

inline std::vector<double> solveTridiagonalMatrixAlgorithm(const TDMACoef & tdma_coef)
{
const auto & a = tdma_coef.a;
const auto & b = tdma_coef.b;
const auto & c = tdma_coef.c;
const auto & d = tdma_coef.d;

const size_t num_row = b.size();

std::vector<double> x(num_row);
if (num_row != 1) {
// calculate p and q
std::vector<double> p;
std::vector<double> q;
p.push_back(-c[0] / b[0]);
q.push_back(d[0] / b[0]);

for (size_t i = 1; i < num_row; ++i) {
const double den = b[i] + a[i - 1] * p[i - 1];
p.push_back(-c[i - 1] / den);
q.push_back((d[i] - a[i - 1] * q[i - 1]) / den);
}

// calculate solution
x[num_row - 1] = q[num_row - 1];

for (size_t i = 1; i < num_row; ++i) {
const size_t j = num_row - 1 - i;
x[j] = p[j] * x[j + 1] + q[j];
}
} else {
x.push_back(d[0] / b[0]);
// Forward sweep
c_prime(0) = c(0) / b(0);
d_prime(0) = d(0) / b(0);

for (auto i = 1; i < n; i++) {
double m = 1.0 / (b(i) - a(i - 1) * c_prime(i - 1));
c_prime(i) = i < n - 1 ? c(i) * m : 0;
d_prime(i) = (d(i) - a(i - 1) * d_prime(i - 1)) * m;
}

// Back substitution
x(n - 1) = d_prime(n - 1);

for (auto i = n - 2; i >= 0; i--) {
x(i) = d_prime(i) - c_prime(i) * x(i + 1);
}

return x;
}
} // namespace

namespace interpolation
{
Expand All @@ -101,73 +74,74 @@ void SplineInterpolation::calcSplineCoefficients(
// throw exceptions for invalid arguments
interpolation_utils::validateKeysAndValues(base_keys, base_values);

const size_t num_base = base_keys.size(); // N+1

std::vector<double> diff_keys; // N
std::vector<double> diff_values; // N
for (size_t i = 0; i < num_base - 1; ++i) {
diff_keys.push_back(base_keys.at(i + 1) - base_keys.at(i));
diff_values.push_back(base_values.at(i + 1) - base_values.at(i));
}

std::vector<double> v = {0.0};
if (num_base > 2) {
// solve tridiagonal matrix algorithm
TDMACoef tdma_coef(num_base - 2); // N-1

for (size_t i = 0; i < num_base - 2; ++i) {
tdma_coef.b[i] = 2 * (diff_keys[i] + diff_keys[i + 1]);
if (i != num_base - 3) {
tdma_coef.a[i] = diff_keys[i + 1];
tdma_coef.c[i] = diff_keys[i + 1];
}
tdma_coef.d[i] =
6.0 * (diff_values[i + 1] / diff_keys[i + 1] - diff_values[i] / diff_keys[i]);
}

const std::vector<double> tdma_res = solveTridiagonalMatrixAlgorithm(tdma_coef);

// calculate v
v.insert(v.end(), tdma_res.begin(), tdma_res.end());
}
v.push_back(0.0);

// calculate a, b, c, d of spline coefficients
multi_spline_coef_ = interpolation::MultiSplineCoef{num_base - 1}; // N
for (size_t i = 0; i < num_base - 1; ++i) {
multi_spline_coef_.a[i] = (v[i + 1] - v[i]) / 6.0 / diff_keys[i];
multi_spline_coef_.b[i] = v[i] / 2.0;
multi_spline_coef_.c[i] =
diff_values[i] / diff_keys[i] - diff_keys[i] * (2 * v[i] + v[i + 1]) / 6.0;
multi_spline_coef_.d[i] = base_values[i];
Eigen::VectorXd x = Eigen::Map<const Eigen::VectorXd>(
base_keys.data(), static_cast<Eigen::Index>(base_keys.size()));
Eigen::VectorXd y = Eigen::Map<const Eigen::VectorXd>(
base_values.data(), static_cast<Eigen::Index>(base_values.size()));

const auto n = x.size();

if (n == 2) {
a_ = Eigen::VectorXd::Zero(1);
b_ = Eigen::VectorXd::Zero(1);
c_ = Eigen::VectorXd::Zero(1);
d_ = Eigen::VectorXd::Zero(1);
c_[0] = (y[1] - y[0]) / (x[1] - x[0]);
d_[0] = y[0];
base_keys_ = base_keys;
return;
}

// Create Tridiagonal matrix
Eigen::VectorXd v(n);
Eigen::VectorXd h = x.segment(1, n - 1) - x.segment(0, n - 1);
Eigen::VectorXd a = h.segment(1, n - 3);
Eigen::VectorXd b = 2 * (h.segment(0, n - 2) + h.segment(1, n - 2));
Eigen::VectorXd c = h.segment(1, n - 3);
Eigen::VectorXd y_diff = y.segment(1, n - 1) - y.segment(0, n - 1);
Eigen::VectorXd d = 6 * (y_diff.segment(1, n - 2).array() / h.tail(n - 2).array() -
y_diff.segment(0, n - 2).array() / h.head(n - 2).array());

// Solve tridiagonal matrix
v.segment(1, n - 2) = solve_tridiagonal_matrix_algorithm(a, b, c, d);
v[0] = 0;
v[n - 1] = 0;

// Calculate spline coefficients
a_ = (v.tail(n - 1) - v.head(n - 1)).array() / 6.0 / (x.tail(n - 1) - x.head(n - 1)).array();
b_ = v.segment(0, n - 1) / 2.0;
c_ = (y.tail(n - 1) - y.head(n - 1)).array() / (x.tail(n - 1) - x.head(n - 1)).array() -
(x.tail(n - 1) - x.head(n - 1)).array() *
(2 * v.segment(0, n - 1).array() + v.segment(1, n - 1).array()) / 6.0;
d_ = y.head(n - 1);
base_keys_ = base_keys;
}

Eigen::Index SplineInterpolation::get_index(double key) const
{
auto it = std::lower_bound(base_keys_.begin(), base_keys_.end(), key);
return std::clamp(
static_cast<int>(std::distance(base_keys_.begin(), it)) - 1, 0,
static_cast<int>(base_keys_.size()) - 2);
}

std::vector<double> SplineInterpolation::getSplineInterpolatedValues(
const std::vector<double> & query_keys) const
{
// throw exceptions for invalid arguments
interpolation_utils::validateKeys(base_keys_, query_keys);

const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;
const auto & c = multi_spline_coef_.c;
const auto & d = multi_spline_coef_.d;
std::vector<double> interpolated_values;
interpolated_values.reserve(query_keys.size());

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}

const double ds = query_key - base_keys_.at(j);
res.push_back(d.at(j) + (c.at(j) + (b.at(j) + a.at(j) * ds) * ds) * ds);
for (const auto & key : query_keys) {
const auto idx = get_index(key);
const auto dx = key - base_keys_[idx];
interpolated_values.emplace_back(
a_[idx] * dx * dx * dx + b_[idx] * dx * dx + c_[idx] * dx + d_[idx]);
}

return res;
return interpolated_values;
}

std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
Expand All @@ -176,20 +150,14 @@ std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
// throw exceptions for invalid arguments
interpolation_utils::validateKeys(base_keys_, query_keys);

const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;
const auto & c = multi_spline_coef_.c;

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}
std::vector<double> interpolated_diff_values;
interpolated_diff_values.reserve(query_keys.size());

const double ds = query_key - base_keys_.at(j);
res.push_back(c.at(j) + (2.0 * b.at(j) + 3.0 * a.at(j) * ds) * ds);
for (const auto & key : query_keys) {
const auto idx = get_index(key);
const auto dx = key - base_keys_[idx];
interpolated_diff_values.emplace_back(3 * a_[idx] * dx * dx + 2 * b_[idx] * dx + c_[idx]);
}

return res;
return interpolated_diff_values;
}
Loading