Skip to content

Commit

Permalink
Add the two-filters method for Kalman Smoothing (#788)
Browse files Browse the repository at this point in the history
  • Loading branch information
beomki-yeo authored Dec 19, 2024
1 parent c06d483 commit 8620769
Show file tree
Hide file tree
Showing 16 changed files with 453 additions and 78 deletions.
1 change: 1 addition & 0 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ traccc_add_library( traccc_core core TYPE SHARED
"include/traccc/fitting/kalman_filter/kalman_fitter.hpp"
"include/traccc/fitting/kalman_filter/kalman_step_aborter.hpp"
"include/traccc/fitting/kalman_filter/statistics_updater.hpp"
"include/traccc/fitting/kalman_filter/two_filters_smoother.hpp"
"include/traccc/fitting/details/fit_tracks.hpp"
"include/traccc/fitting/kalman_fitting_algorithm.hpp"
"src/fitting/kalman_fitting_algorithm.cpp"
Expand Down
16 changes: 16 additions & 0 deletions core/include/traccc/edm/track_parameters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,20 @@ inline void wrap_phi(bound_track_parameters& param) {
param.set_phi(phi);
}

/// Covariance inflation used for track fitting
TRACCC_HOST_DEVICE
inline void inflate_covariance(bound_track_parameters& param,
const traccc::scalar inf_fac) {
auto& cov = param.covariance();
for (unsigned int i = 0; i < e_bound_size; i++) {
for (unsigned int j = 0; j < e_bound_size; j++) {
if (i == j) {
getter::element(cov, i, i) *= inf_fac;
} else {
getter::element(cov, i, j) = 0.f;
}
}
}
}

} // namespace traccc
17 changes: 17 additions & 0 deletions core/include/traccc/edm/track_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ struct fitting_result {
// The number of holes (The number of sensitive surfaces which do not have a
// measurement for the track pattern)
unsigned int n_holes{0u};

/// Reset the statistics
TRACCC_HOST_DEVICE
void reset_statistics() {
ndf = 0.f;
chi2 = 0.f;
n_holes = 0u;
}
};

/// Fitting result per measurement
Expand Down Expand Up @@ -160,6 +168,14 @@ struct track_state {
TRACCC_HOST_DEVICE
inline const scalar_type& filtered_chi2() const { return m_filtered_chi2; }

/// @return the non-const chi square of backward filter
TRACCC_HOST_DEVICE
inline scalar_type& backward_chi2() { return m_backward_chi2; }

/// @return the const chi square of backward filter
TRACCC_HOST_DEVICE
inline scalar_type backward_chi2() const { return m_backward_chi2; }

/// @return the non-const filtered parameter
TRACCC_HOST_DEVICE
inline bound_track_parameters_type& filtered() { return m_filtered; }
Expand Down Expand Up @@ -200,6 +216,7 @@ struct track_state {
bound_track_parameters_type m_filtered;
scalar_type m_smoothed_chi2 = 0.f;
bound_track_parameters_type m_smoothed;
scalar_type m_backward_chi2 = 0.f;
};

/// Declare all track_state collection types
Expand Down
4 changes: 4 additions & 0 deletions core/include/traccc/fitting/fitting_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ struct fitting_config {
/// Particle hypothesis
detray::pdg_particle<traccc::scalar> ptc_hypothesis =
detray::muon<traccc::scalar>();

/// Smoothing with backward filter
bool use_backward_filter = false;
traccc::scalar covariance_inflation_factor = 1e3f;
};

} // namespace traccc
68 changes: 53 additions & 15 deletions core/include/traccc/fitting/kalman_filter/kalman_actor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "traccc/definitions/qualifiers.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
#include "traccc/fitting/kalman_filter/two_filters_smoother.hpp"
#include "traccc/utils/particle.hpp"

// detray include(s).
Expand All @@ -33,31 +34,50 @@ struct kalman_actor : detray::actor {
state(vector_t<track_state_type>&& track_states)
: m_track_states(std::move(track_states)) {
m_it = m_track_states.begin();
m_it_rev = m_track_states.rbegin();
}

/// Constructor with the vector of track states
TRACCC_HOST_DEVICE
state(const vector_t<track_state_type>& track_states)
: m_track_states(track_states) {
m_it = m_track_states.begin();
m_it_rev = m_track_states.rbegin();
}

/// @return the reference of track state pointed by the iterator
TRACCC_HOST_DEVICE
track_state_type& operator()() { return *m_it; }
track_state_type& operator()() {
if (!backward_mode) {
return *m_it;
} else {
return *m_it_rev;
}
}

/// Reset the iterator
TRACCC_HOST_DEVICE
void reset() { m_it = m_track_states.begin(); }
void reset() {
m_it = m_track_states.begin();
m_it_rev = m_track_states.rbegin();
}

/// Advance the iterator
TRACCC_HOST_DEVICE
void next() { m_it++; }
void next() {
if (!backward_mode) {
m_it++;
} else {
m_it_rev++;
}
}

/// @return true if the iterator reaches the end of vector
TRACCC_HOST_DEVICE
bool is_complete() const {
if (m_it == m_track_states.end()) {
bool is_complete() {
if (!backward_mode && m_it == m_track_states.end()) {
return true;
} else if (backward_mode && m_it_rev == m_track_states.rend()) {
return true;
}
return false;
Expand All @@ -69,9 +89,15 @@ struct kalman_actor : detray::actor {
// iterator for forward filtering
typename vector_t<track_state_type>::iterator m_it;

// iterator for backward filtering
typename vector_t<track_state_type>::reverse_iterator m_it_rev;

// The number of holes (The number of sensitive surfaces which do not
// have a measurement for the track pattern)
unsigned int n_holes{0u};

// Run back filtering for smoothing, if true
bool backward_mode = false;
};

/// Actor operation to perform the Kalman filtering
Expand Down Expand Up @@ -99,32 +125,44 @@ struct kalman_actor : detray::actor {
// Increase the hole counts if the propagator fails to find the next
// measurement
if (navigation.barcode() != trk_state.surface_link()) {
actor_state.n_holes++;
if (!actor_state.backward_mode) {
actor_state.n_holes++;
}
return;
}

// This track state is not a hole
trk_state.is_hole = false;
if (!actor_state.backward_mode) {
trk_state.is_hole = false;
}

// Run Kalman Gain Updater
const auto sf = navigation.get_surface();

const bool res =
sf.template visit_mask<gain_matrix_updater<algebra_t>>(
bool res = false;

if (!actor_state.backward_mode) {
// Forward filter
res = sf.template visit_mask<gain_matrix_updater<algebra_t>>(
trk_state, propagation._stepping.bound_params());

// Update the propagation flow
stepping.bound_params() = trk_state.filtered();

// Set full jacobian
trk_state.jacobian() = stepping.full_jacobian();
} else {
// Backward filter for smoothing
res = sf.template visit_mask<two_filters_smoother<algebra_t>>(
trk_state, propagation._stepping.bound_params());
}

// Abort if the Kalman update fails
if (!res) {
propagation._heartbeat &= navigation.abort();
return;
}

// Update the propagation flow
stepping.bound_params() = trk_state.filtered();

// Set full jacobian
trk_state.jacobian() = stepping.full_jacobian();

// Change the charge of hypothesized particles when the sign of qop
// is changed (This rarely happens when qop is set with a poor seed
// resolution)
Expand Down
94 changes: 69 additions & 25 deletions core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "traccc/fitting/kalman_filter/kalman_actor.hpp"
#include "traccc/fitting/kalman_filter/kalman_step_aborter.hpp"
#include "traccc/fitting/kalman_filter/statistics_updater.hpp"
#include "traccc/fitting/kalman_filter/two_filters_smoother.hpp"
#include "traccc/utils/particle.hpp"

// detray include(s).
Expand Down Expand Up @@ -67,10 +68,17 @@ class kalman_fitter {
detray::actor_chain<detray::dtuple, aborter, transporter, interactor,
fit_actor, resetter, kalman_step_aborter>;

using backward_actor_chain_type =
detray::actor_chain<detray::dtuple, aborter, transporter, fit_actor,
interactor, resetter, kalman_step_aborter>;

// Propagator type
using propagator_type =
detray::propagator<stepper_t, navigator_t, actor_chain_type>;

using backward_propagator_type =
detray::propagator<stepper_t, navigator_t, backward_actor_chain_type>;

/// Constructor with a detector
///
/// @param det the detector object
Expand Down Expand Up @@ -104,6 +112,14 @@ class kalman_fitter {
m_resetter_state, m_step_aborter_state);
}

/// @return the actor chain state
TRACCC_HOST_DEVICE
typename backward_actor_chain_type::state backward_actor_state() {
return detray::tie(m_aborter_state, m_transporter_state,
m_fit_actor_state, m_interactor_state,
m_resetter_state, m_step_aborter_state);
}

/// Individual actor states
typename aborter::state m_aborter_state{};
typename transporter::state m_transporter_state{};
Expand Down Expand Up @@ -132,17 +148,15 @@ class kalman_fitter {
// Reset the iterator of kalman actor
fitter_state.m_fit_actor_state.reset();

if (i == 0) {
filter(seed_params, fitter_state);
}
// From the second iteration, seed parameter is the smoothed track
// parameter at the first surface
else {
const auto& new_seed_params =
fitter_state.m_fit_actor_state.m_track_states[0].smoothed();
auto seed_params_cpy =
(i == 0) ? seed_params
: fitter_state.m_fit_actor_state.m_track_states[0]
.smoothed();

filter(new_seed_params, fitter_state);
}
inflate_covariance(seed_params_cpy,
m_cfg.covariance_inflation_factor);

filter(seed_params_cpy, fitter_state);
}
}

Expand Down Expand Up @@ -178,6 +192,9 @@ class kalman_fitter {
.template set_constraint<detray::step::constraint::e_accuracy>(
m_cfg.propagation.stepping.step_constraint);

// Reset fitter statistics
fitter_state.m_fit_res.reset_statistics();

// Run forward filtering
propagator.propagate(propagation, fitter_state());

Expand All @@ -194,14 +211,10 @@ class kalman_fitter {
/// track and vertex fitting", R.Frühwirth, NIM A.
///
/// @param fitter_state the state of kalman fitter
TRACCC_HOST_DEVICE
void smooth(state& fitter_state) {
TRACCC_HOST_DEVICE void smooth(state& fitter_state) {

auto& track_states = fitter_state.m_fit_actor_state.m_track_states;

// The smoothing algorithm requires the following:
// (1) the filtered track parameter of the current surface
// (2) the smoothed track parameter of the next surface
//
// Since the smoothed track parameter of the last surface can be
// considered to be the filtered one, we can reversly iterate the
// algorithm to obtain the smoothed parameter of other surfaces
Expand All @@ -210,14 +223,45 @@ class kalman_fitter {
last.smoothed().set_covariance(last.filtered().covariance());
last.smoothed_chi2() = last.filtered_chi2();

for (typename vector_type<track_state<algebra_type>>::reverse_iterator
it = track_states.rbegin() + 1;
it != track_states.rend(); ++it) {
if (m_cfg.use_backward_filter) {
// Backward propagator for the two-filters method
backward_propagator_type propagator(m_cfg.propagation);

// Set path limit
fitter_state.m_aborter_state.set_path_limit(
m_cfg.propagation.stepping.path_limit);

typename backward_propagator_type::state propagation(
last.smoothed(), m_field, m_detector);

inflate_covariance(propagation._stepping.bound_params(),
m_cfg.covariance_inflation_factor);

propagation._navigation.set_volume(
last.smoothed().surface_link().volume());

// Run kalman smoother
const detray::tracking_surface sf{m_detector, it->surface_link()};
sf.template visit_mask<gain_matrix_smoother<algebra_type>>(
*it, *(it - 1));
propagation._navigation.set_direction(
detray::navigation::direction::e_backward);
fitter_state.m_fit_actor_state.backward_mode = true;

propagator.propagate(propagation,
fitter_state.backward_actor_state());

// Reset the backward mode to false
fitter_state.m_fit_actor_state.backward_mode = false;

} else {
// Run the Rauch–Tung–Striebel (RTS) smoother
for (typename vector_type<
track_state<algebra_type>>::reverse_iterator it =
track_states.rbegin() + 1;
it != track_states.rend(); ++it) {

const detray::tracking_surface sf{m_detector,
it->surface_link()};
sf.template visit_mask<gain_matrix_smoother<algebra_type>>(
*it, *(it - 1));
}
}
}

Expand All @@ -233,8 +277,8 @@ class kalman_fitter {

const detray::tracking_surface sf{m_detector,
trk_state.surface_link()};
sf.template visit_mask<statistics_updater<algebra_type>>(fit_res,
trk_state);
sf.template visit_mask<statistics_updater<algebra_type>>(
fit_res, trk_state, m_cfg.use_backward_filter);
}

// Subtract the NDoF with the degree of freedom of the bound track (=5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ struct statistics_updater {
TRACCC_HOST_DEVICE inline void operator()(
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
fitting_result<algebra_t>& fit_res,
const track_state<algebra_t>& trk_state) {
const track_state<algebra_t>& trk_state,
const bool use_backward_filter) {

if (!trk_state.is_hole) {

Expand All @@ -41,7 +42,11 @@ struct statistics_updater {
fit_res.ndf += static_cast<scalar_type>(D);

// total_chi2 = total_chi2 + chi2
fit_res.chi2 += trk_state.smoothed_chi2();
if (use_backward_filter) {
fit_res.chi2 += trk_state.backward_chi2();
} else {
fit_res.chi2 += trk_state.filtered_chi2();
}
}
}
};
Expand Down
Loading

0 comments on commit 8620769

Please sign in to comment.