From 6c32bf4a9cfffabb9e17972370f767acc053be39 Mon Sep 17 00:00:00 2001 From: "Y.Hisaki" Date: Tue, 17 Sep 2024 18:11:24 +0900 Subject: [PATCH] fix spline bug Signed-off-by: Y.Hisaki --- .../interpolation/spline_interpolation.hpp | 32 +-- .../spline_interpolation_points_2d.hpp | 3 + common/interpolation/package.xml | 1 + .../src/spline_interpolation.cpp | 210 ++++++++---------- 4 files changed, 101 insertions(+), 145 deletions(-) diff --git a/common/interpolation/include/interpolation/spline_interpolation.hpp b/common/interpolation/include/interpolation/spline_interpolation.hpp index 09a01d03727eb..578b08a1fa225 100644 --- a/common/interpolation/include/interpolation/spline_interpolation.hpp +++ b/common/interpolation/include/interpolation/spline_interpolation.hpp @@ -15,35 +15,13 @@ #ifndef INTERPOLATION__SPLINE_INTERPOLATION_HPP_ #define INTERPOLATION__SPLINE_INTERPOLATION_HPP_ -#include "interpolation/interpolation_utils.hpp" -#include "tier4_autoware_utils/geometry/geometry.hpp" +#include -#include #include -#include -#include #include namespace interpolation { -// NOTE: X(s) = a_i (s - s_i)^3 + b_i (s - s_i)^2 + c_i (s - s_i) + d_i : (i = 0, 1, ... N-1) -struct MultiSplineCoef -{ - MultiSplineCoef() = default; - - explicit MultiSplineCoef(const size_t num_spline) - { - a.resize(num_spline); - b.resize(num_spline); - c.resize(num_spline); - d.resize(num_spline); - } - - std::vector a; - std::vector b; - std::vector c; - std::vector d; -}; // static spline interpolation functions std::vector slerp( @@ -84,8 +62,14 @@ class SplineInterpolation std::vector getSplineInterpolatedDiffValues(const std::vector & query_keys) const; private: + Eigen::VectorXd a_; + Eigen::VectorXd b_; + Eigen::VectorXd c_; + Eigen::VectorXd d_; + std::vector base_keys_; - interpolation::MultiSplineCoef multi_spline_coef_; + + Eigen::Index get_index(double key) const; }; #endif // INTERPOLATION__SPLINE_INTERPOLATION_HPP_ diff --git a/common/interpolation/include/interpolation/spline_interpolation_points_2d.hpp b/common/interpolation/include/interpolation/spline_interpolation_points_2d.hpp index c1f08a6d937ae..f46b64bba4d6a 100644 --- a/common/interpolation/include/interpolation/spline_interpolation_points_2d.hpp +++ b/common/interpolation/include/interpolation/spline_interpolation_points_2d.hpp @@ -16,6 +16,9 @@ #define INTERPOLATION__SPLINE_INTERPOLATION_POINTS_2D_HPP_ #include "interpolation/spline_interpolation.hpp" +#include "tier4_autoware_utils/geometry/geometry.hpp" + +#include #include diff --git a/common/interpolation/package.xml b/common/interpolation/package.xml index 72844e0702978..9f94f9221100e 100644 --- a/common/interpolation/package.xml +++ b/common/interpolation/package.xml @@ -9,6 +9,7 @@ Apache License 2.0 ament_cmake_auto + eigen tier4_autoware_utils ament_lint_auto diff --git a/common/interpolation/src/spline_interpolation.cpp b/common/interpolation/src/spline_interpolation.cpp index f8d14ff7bba37..bbcaafdef14e6 100644 --- a/common/interpolation/src/spline_interpolation.cpp +++ b/common/interpolation/src/spline_interpolation.cpp @@ -14,70 +14,43 @@ #include "interpolation/spline_interpolation.hpp" -#include +#include "interpolation/interpolation_utils.hpp" -namespace -{ -// solve Ax = d -// where A is tridiagonal matrix -// [b_0 c_0 ... ] -// [a_0 b_1 c_1 ... O ] -// A = [ ... ] -// [ O ... a_N-3 b_N-2 c_N-2] -// [ ... a_N-2 b_N-1] -struct TDMACoef +#include + +Eigen::VectorXd solve_tridiagonal_matrix_algorithm( + const Eigen::Ref & a, const Eigen::Ref & b, + const Eigen::Ref & c, const Eigen::Ref & d) { - explicit TDMACoef(const size_t num_row) - { - a.resize(num_row - 1); - b.resize(num_row); - c.resize(num_row - 1); - d.resize(num_row); + auto n = d.size(); + + if (n == 1) { + return d.array() / b.array(); } - std::vector a; - std::vector b; - std::vector c; - std::vector d; -}; + Eigen::VectorXd c_prime = Eigen::VectorXd::Zero(n); + Eigen::VectorXd d_prime = Eigen::VectorXd::Zero(n); + Eigen::VectorXd x = Eigen::VectorXd::Zero(n); -inline std::vector solveTridiagonalMatrixAlgorithm(const TDMACoef & tdma_coef) -{ - const auto & a = tdma_coef.a; - const auto & b = tdma_coef.b; - const auto & c = tdma_coef.c; - const auto & d = tdma_coef.d; - - const size_t num_row = b.size(); - - std::vector x(num_row); - if (num_row != 1) { - // calculate p and q - std::vector p; - std::vector q; - p.push_back(-c[0] / b[0]); - q.push_back(d[0] / b[0]); - - for (size_t i = 1; i < num_row; ++i) { - const double den = b[i] + a[i - 1] * p[i - 1]; - p.push_back(-c[i - 1] / den); - q.push_back((d[i] - a[i - 1] * q[i - 1]) / den); - } - - // calculate solution - x[num_row - 1] = q[num_row - 1]; - - for (size_t i = 1; i < num_row; ++i) { - const size_t j = num_row - 1 - i; - x[j] = p[j] * x[j + 1] + q[j]; - } - } else { - x.push_back(d[0] / b[0]); + // Forward sweep + c_prime(0) = c(0) / b(0); + d_prime(0) = d(0) / b(0); + + for (auto i = 1; i < n; i++) { + double m = 1.0 / (b(i) - a(i - 1) * c_prime(i - 1)); + c_prime(i) = i < n - 1 ? c(i) * m : 0; + d_prime(i) = (d(i) - a(i - 1) * d_prime(i - 1)) * m; + } + + // Back substitution + x(n - 1) = d_prime(n - 1); + + for (auto i = n - 2; i >= 0; i--) { + x(i) = d_prime(i) - c_prime(i) * x(i + 1); } return x; } -} // namespace namespace interpolation { @@ -101,73 +74,74 @@ void SplineInterpolation::calcSplineCoefficients( // throw exceptions for invalid arguments interpolation_utils::validateKeysAndValues(base_keys, base_values); - const size_t num_base = base_keys.size(); // N+1 - - std::vector diff_keys; // N - std::vector diff_values; // N - for (size_t i = 0; i < num_base - 1; ++i) { - diff_keys.push_back(base_keys.at(i + 1) - base_keys.at(i)); - diff_values.push_back(base_values.at(i + 1) - base_values.at(i)); - } - - std::vector v = {0.0}; - if (num_base > 2) { - // solve tridiagonal matrix algorithm - TDMACoef tdma_coef(num_base - 2); // N-1 - - for (size_t i = 0; i < num_base - 2; ++i) { - tdma_coef.b[i] = 2 * (diff_keys[i] + diff_keys[i + 1]); - if (i != num_base - 3) { - tdma_coef.a[i] = diff_keys[i + 1]; - tdma_coef.c[i] = diff_keys[i + 1]; - } - tdma_coef.d[i] = - 6.0 * (diff_values[i + 1] / diff_keys[i + 1] - diff_values[i] / diff_keys[i]); - } - - const std::vector tdma_res = solveTridiagonalMatrixAlgorithm(tdma_coef); - - // calculate v - v.insert(v.end(), tdma_res.begin(), tdma_res.end()); - } - v.push_back(0.0); - - // calculate a, b, c, d of spline coefficients - multi_spline_coef_ = interpolation::MultiSplineCoef{num_base - 1}; // N - for (size_t i = 0; i < num_base - 1; ++i) { - multi_spline_coef_.a[i] = (v[i + 1] - v[i]) / 6.0 / diff_keys[i]; - multi_spline_coef_.b[i] = v[i] / 2.0; - multi_spline_coef_.c[i] = - diff_values[i] / diff_keys[i] - diff_keys[i] * (2 * v[i] + v[i + 1]) / 6.0; - multi_spline_coef_.d[i] = base_values[i]; + Eigen::VectorXd x = Eigen::Map( + base_keys.data(), static_cast(base_keys.size())); + Eigen::VectorXd y = Eigen::Map( + base_values.data(), static_cast(base_values.size())); + + const auto n = x.size(); + + if (n == 2) { + a_ = Eigen::VectorXd::Zero(1); + b_ = Eigen::VectorXd::Zero(1); + c_ = Eigen::VectorXd::Zero(1); + d_ = Eigen::VectorXd::Zero(1); + c_[0] = (y[1] - y[0]) / (x[1] - x[0]); + d_[0] = y[0]; + base_keys_ = base_keys; + return; } + // Create Tridiagonal matrix + Eigen::VectorXd v(n); + Eigen::VectorXd h = x.segment(1, n - 1) - x.segment(0, n - 1); + Eigen::VectorXd a = h.segment(1, n - 3); + Eigen::VectorXd b = 2 * (h.segment(0, n - 2) + h.segment(1, n - 2)); + Eigen::VectorXd c = h.segment(1, n - 3); + Eigen::VectorXd y_diff = y.segment(1, n - 1) - y.segment(0, n - 1); + Eigen::VectorXd d = 6 * (y_diff.segment(1, n - 2).array() / h.tail(n - 2).array() - + y_diff.segment(0, n - 2).array() / h.head(n - 2).array()); + + // Solve tridiagonal matrix + v.segment(1, n - 2) = solve_tridiagonal_matrix_algorithm(a, b, c, d); + v[0] = 0; + v[n - 1] = 0; + + // Calculate spline coefficients + a_ = (v.tail(n - 1) - v.head(n - 1)).array() / 6.0 / (x.tail(n - 1) - x.head(n - 1)).array(); + b_ = v.segment(0, n - 1) / 2.0; + c_ = (y.tail(n - 1) - y.head(n - 1)).array() / (x.tail(n - 1) - x.head(n - 1)).array() - + (x.tail(n - 1) - x.head(n - 1)).array() * + (2 * v.segment(0, n - 1).array() + v.segment(1, n - 1).array()) / 6.0; + d_ = y.head(n - 1); base_keys_ = base_keys; } +Eigen::Index SplineInterpolation::get_index(double key) const +{ + auto it = std::lower_bound(base_keys_.begin(), base_keys_.end(), key); + return std::clamp( + static_cast(std::distance(base_keys_.begin(), it)) - 1, 0, + static_cast(base_keys_.size()) - 2); +} + std::vector SplineInterpolation::getSplineInterpolatedValues( const std::vector & query_keys) const { // throw exceptions for invalid arguments interpolation_utils::validateKeys(base_keys_, query_keys); - const auto & a = multi_spline_coef_.a; - const auto & b = multi_spline_coef_.b; - const auto & c = multi_spline_coef_.c; - const auto & d = multi_spline_coef_.d; + std::vector interpolated_values; + interpolated_values.reserve(query_keys.size()); - std::vector res; - size_t j = 0; - for (const auto & query_key : query_keys) { - while (base_keys_.at(j + 1) < query_key) { - ++j; - } - - const double ds = query_key - base_keys_.at(j); - res.push_back(d.at(j) + (c.at(j) + (b.at(j) + a.at(j) * ds) * ds) * ds); + for (const auto & key : query_keys) { + const auto idx = get_index(key); + const auto dx = key - base_keys_[idx]; + interpolated_values.emplace_back( + a_[idx] * dx * dx * dx + b_[idx] * dx * dx + c_[idx] * dx + d_[idx]); } - return res; + return interpolated_values; } std::vector SplineInterpolation::getSplineInterpolatedDiffValues( @@ -176,20 +150,14 @@ std::vector SplineInterpolation::getSplineInterpolatedDiffValues( // throw exceptions for invalid arguments interpolation_utils::validateKeys(base_keys_, query_keys); - const auto & a = multi_spline_coef_.a; - const auto & b = multi_spline_coef_.b; - const auto & c = multi_spline_coef_.c; - - std::vector res; - size_t j = 0; - for (const auto & query_key : query_keys) { - while (base_keys_.at(j + 1) < query_key) { - ++j; - } + std::vector interpolated_diff_values; + interpolated_diff_values.reserve(query_keys.size()); - const double ds = query_key - base_keys_.at(j); - res.push_back(c.at(j) + (2.0 * b.at(j) + 3.0 * a.at(j) * ds) * ds); + for (const auto & key : query_keys) { + const auto idx = get_index(key); + const auto dx = key - base_keys_[idx]; + interpolated_diff_values.emplace_back(3 * a_[idx] * dx * dx + 2 * b_[idx] * dx + c_[idx]); } - return res; + return interpolated_diff_values; }