Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add GSF component momentum cut to stabilize fit #2661

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions Core/include/Acts/TrackFitting/GaussianSumFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@ struct GaussianSumFitter {
/// The actor type
using GsfActor = detail::GsfActor<bethe_heitler_approx_t, traj_t>;

/// This allows to break the propagation by setting the navigationBreak
/// TODO refactor once we can do this more elegantly
struct NavigationBreakAborter {
NavigationBreakAborter() = default;

template <typename propagator_state_t, typename stepper_t,
typename navigator_t>
bool operator()(propagator_state_t& state, const stepper_t& /*stepper*/,
const navigator_t& navigator,
const Logger& /*logger*/) const {
return navigator.navigationBreak(state.navigation);
}
};

/// @brief The fit function for the Direct navigator
template <typename source_link_it_t, typename start_parameters_t,
typename track_container_t, template <typename> class holder_t>
Expand All @@ -92,7 +106,7 @@ struct GaussianSumFitter {
// Initialize the forward propagation with the DirectNavigator
auto fwdPropInitializer = [&sSequence, this](const auto& opts) {
using Actors = ActionList<GsfActor, DirectNavigator::Initializer>;
using Aborters = AbortList<>;
using Aborters = AbortList<NavigationBreakAborter>;

PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
opts.magFieldContext);
Expand Down Expand Up @@ -147,7 +161,7 @@ struct GaussianSumFitter {
// Initialize the forward propagation with the DirectNavigator
auto fwdPropInitializer = [this](const auto& opts) {
using Actors = ActionList<GsfActor>;
using Aborters = AbortList<EndOfWorldReached>;
using Aborters = AbortList<EndOfWorldReached, NavigationBreakAborter>;

PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
opts.magFieldContext);
Expand Down Expand Up @@ -370,7 +384,12 @@ struct GaussianSumFitter {
r.measurementStates++;
r.processedStates++;

const auto& params = *fwdGsfResult.lastMeasurementState;
assert(!fwdGsfResult.lastMeasurementComponents.empty());
assert(fwdGsfResult.lastMeasurementSurface != nullptr);
MultiComponentBoundTrackParameters params(
fwdGsfResult.lastMeasurementSurface->getSharedPtr(),
fwdGsfResult.lastMeasurementComponents,
sParameters.particleHypothesis());

return m_propagator.template propagate<std::decay_t<decltype(params)>,
decltype(bwdPropOptions),
Expand Down
2 changes: 2 additions & 0 deletions Core/include/Acts/TrackFitting/GsfOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ struct GsfOptions {

double weightCutoff = 1.e-4;

double momentumCutoff = 500_MeV;

bool abortOnError = false;

bool disableAllMaterialHandling = false;
Expand Down
160 changes: 86 additions & 74 deletions Core/include/Acts/TrackFitting/detail/GsfActor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ struct GsfResult {

/// The last multi-component measurement state. Used to initialize the
/// backward pass.
std::optional<MultiComponentBoundTrackParameters> lastMeasurementState;
std::vector<std::tuple<double, BoundVector, BoundMatrix>>
lastMeasurementComponents;

/// The last measurement surface. Used to initialize the backward pass.
const Acts::Surface* lastMeasurementSurface = nullptr;

/// Some counting
std::size_t measurementStates = 0;
Expand Down Expand Up @@ -94,6 +98,9 @@ struct GsfActor {
/// When to discard components
double weightCutoff = 1.0e-4;

/// When to discard components
double momentumCutoff = 500_MeV;

/// When this option is enabled, material information on all surfaces is
/// ignored. This disables the component convolution as well as the handling
/// of energy. This may be useful for debugging.
Expand Down Expand Up @@ -132,6 +139,8 @@ struct GsfActor {
std::map<MultiTrajectoryTraits::IndexType, double> weights;
};

using FiltProjector = MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;

/// @brief GSF actor operation
///
/// @tparam propagator_state_t is the type of Propagator state
Expand Down Expand Up @@ -260,7 +269,8 @@ struct GsfActor {
return;
}

updateStepper(state, stepper, tmpStates);
FiltProjector proj{tmpStates.traj, tmpStates.weights};
updateStepper(state, stepper, navigator, tmpStates.tips, proj);
}
// We have material, we thus need a component cache since we will
// convolute the components and later reduce them again before updating
Expand Down Expand Up @@ -305,7 +315,8 @@ struct GsfActor {

removeLowWeightComponents(componentCache);

updateStepper(state, stepper, navigator, componentCache);
auto proj = [](const auto& a) -> decltype(a) { return a; };
updateStepper(state, stepper, navigator, componentCache, proj);
}

// If we have only done preUpdate before, now do postUpdate
Expand All @@ -317,7 +328,8 @@ struct GsfActor {
// Break the navigation if we found all measurements
if (m_cfg.numberMeasurements &&
result.measurementStates == m_cfg.numberMeasurements) {
navigator.targetReached(state.navigation, true);
ACTS_VERBOSE("Stop navigation because all measurements are found");
navigator.navigationBreak(state.navigation, true);
}
}

Expand Down Expand Up @@ -404,6 +416,13 @@ struct GsfActor {
assert(p_prev + delta_p > 0. && "new momentum must be > 0");
new_pars[eBoundQOverP] = old_bound.charge() / (p_prev + delta_p);

const auto p_new = state.stepping.particleHypothesis.extractMomentum(
new_pars[eBoundQOverP]);
if (p_new < m_cfg.momentumCutoff) {
ACTS_VERBOSE("Skip new component with p=" << p_new << " GeV");
continue;
}

// compute inverse variance of p from mixture and update covariance
auto new_cov = old_bound.covariance().value();

Expand Down Expand Up @@ -450,50 +469,19 @@ struct GsfActor {
}
}

/// Function that updates the stepper from the MultiTrajectory
template <typename propagator_state_t, typename stepper_t>
void updateStepper(propagator_state_t& state, const stepper_t& stepper,
const TemporaryStates& tmpStates) const {
auto cmps = stepper.componentIterable(state.stepping);

for (auto [idx, cmp] : zip(tmpStates.tips, cmps)) {
// we set ignored components to missed, so we can remove them after
// the loop
if (tmpStates.weights.at(idx) < m_cfg.weightCutoff) {
cmp.status() = Intersection3D::Status::missed;
continue;
}

auto proxy = tmpStates.traj.getTrackState(idx);

cmp.pars() =
MultiTrajectoryHelpers::freeFiltered(state.options.geoContext, proxy);
cmp.cov() = proxy.filteredCovariance();
cmp.weight() = tmpStates.weights.at(idx);
}

stepper.removeMissedComponents(state.stepping);

// TODO we have two normalization passes here now, this can probably be
// optimized
detail::normalizeWeights(cmps,
[&](auto cmp) -> double& { return cmp.weight(); });
}

/// Function that updates the stepper from the ComponentCache
template <typename propagator_state_t, typename stepper_t,
typename navigator_t>
typename navigator_t, typename range_t, typename proj_t>
void updateStepper(propagator_state_t& state, const stepper_t& stepper,
const navigator_t& navigator,
const std::vector<ComponentCache>& componentCache) const {
const navigator_t& navigator, const range_t& range,
const proj_t& proj) const {
const auto& surface = *navigator.currentSurface(state.navigation);

// Clear components before adding new ones
stepper.clearComponents(state.stepping);

// Finally loop over components
for (const auto& [weight, pars, cov] : componentCache) {
// Add the component to the stepper
for (const auto& cmp : range) {
const auto& [weight, pars, cov] = proj(cmp);
BoundTrackParameters bound(surface.getSharedPtr(), pars, cov,
stepper.particleHypothesis(state.stepping));

Expand All @@ -504,12 +492,12 @@ struct GsfActor {
continue;
}

auto& cmp = *res;
cmp.jacToGlobal() = surface.boundToFreeJacobian(state.geoContext, pars);
cmp.pathAccumulated() = state.stepping.pathAccumulated;
cmp.jacobian() = Acts::BoundMatrix::Identity();
cmp.derivative() = Acts::FreeVector::Zero();
cmp.jacTransport() = Acts::FreeMatrix::Identity();
auto& proxy = *res;
proxy.jacToGlobal() = surface.boundToFreeJacobian(state.geoContext, pars);
proxy.pathAccumulated() = state.stepping.pathAccumulated;
proxy.jacobian() = Acts::BoundMatrix::Identity();
proxy.derivative() = Acts::FreeVector::Zero();
proxy.jacTransport() = Acts::FreeMatrix::Identity();
}
}

Expand All @@ -523,11 +511,14 @@ struct GsfActor {
const SourceLink& source_link) const {
const auto& surface = *navigator.currentSurface(state.navigation);

// This allows to easily project the to weight, filtered pars, filtered cov
benjaminhuth marked this conversation as resolved.
Show resolved Hide resolved
benjaminhuth marked this conversation as resolved.
Show resolved Hide resolved
const FiltProjector proj{tmpStates.traj, tmpStates.weights};

// Boolean flag, to distinguish measurement and outlier states. This flag
// is only modified by the valid-measurement-branch, so only if there
// isn't any valid measurement state, the flag stays false and the state
// is thus counted as an outlier
bool is_valid_measurement = false;
bool isValidMeasurement = false;

auto cmps = stepper.componentIterable(state.stepping);
for (auto cmp : cmps) {
Expand All @@ -545,54 +536,73 @@ struct GsfActor {

const auto& trackStateProxy = *trackStateProxyRes;

const auto p = state.stepping.particleHypothesis.extractMomentum(
trackStateProxy.filtered()[eBoundQOverP]);
if (p < m_cfg.momentumCutoff) {
ACTS_VERBOSE("Discard component with momentum "
<< p << " GeV after Kalman update");
continue;
}

// If at least one component is no outlier, we consider the whole thing
// as a measurementState
if (trackStateProxy.typeFlags().test(
Acts::TrackStateFlag::MeasurementFlag)) {
is_valid_measurement = true;
isValidMeasurement = true;
}

tmpStates.tips.push_back(trackStateProxy.index());
tmpStates.weights[tmpStates.tips.back()] = cmp.weight();
}

computePosteriorWeights(tmpStates.traj, tmpStates.tips, tmpStates.weights);
// compute the posterior weights
if (!tmpStates.tips.empty()) {
computePosteriorWeights(tmpStates.traj, tmpStates.tips,
tmpStates.weights);

detail::normalizeWeights(tmpStates.tips, [&](auto idx) -> double& {
return tmpStates.weights.at(idx);
});
detail::normalizeWeights(tmpStates.tips, [&](auto idx) -> double& {
return tmpStates.weights.at(idx);
});
}

// Remove low weight components
auto newEnd = std::remove_if(
tmpStates.tips.begin(), tmpStates.tips.end(),
[&](auto t) { return tmpStates.weights[t] < m_cfg.weightCutoff; });

if (newEnd != tmpStates.tips.end()) {
tmpStates.tips.erase(newEnd, tmpStates.tips.end());
detail::normalizeWeights(tmpStates.tips, [&](auto idx) -> double& {
return tmpStates.weights.at(idx);
});
}

// Break navigation if we have no states left
if (tmpStates.tips.empty()) {
ACTS_DEBUG("no components left after Kalman update, break navigation!");
navigator.navigationBreak(state.navigation, true);
return Acts::Result<void>::success();
}

// Do the statistics
++result.processedStates;

// TODO should outlier states also be counted here?
if (is_valid_measurement) {
if (isValidMeasurement) {
++result.measurementStates;
}

addCombinedState(result, tmpStates, surface);
result.lastMeasurementTip = result.currentTip;

using FiltProjector =
MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;
FiltProjector proj{tmpStates.traj, tmpStates.weights};

std::vector<std::tuple<double, BoundVector, BoundMatrix>> v;
updateMultiTrajectory(result, tmpStates, surface);

// TODO Check why can zero weights can occur
result.lastMeasurementTip = result.currentTip;
result.lastMeasurementSurface = &surface;
result.lastMeasurementComponents.clear();
for (const auto& idx : tmpStates.tips) {
const auto [w, p, c] = proj(idx);
if (w > 0.0) {
v.push_back({w, p, c});
}
assert(w > 0.0);
result.lastMeasurementComponents.push_back({w, p, c});
}

normalizeWeights(v, [](auto& c) -> double& { return std::get<double>(c); });

result.lastMeasurementState = MultiComponentBoundTrackParameters(
surface.getSharedPtr(), std::move(v),
stepper.particleHypothesis(state.stepping));

// Return success
return Acts::Result<void>::success();
}
Expand Down Expand Up @@ -643,7 +653,7 @@ struct GsfActor {

++result.processedStates;

addCombinedState(result, tmpStates, surface);
updateMultiTrajectory(result, tmpStates, surface);

return Result<void>::success();
}
Expand Down Expand Up @@ -689,8 +699,9 @@ struct GsfActor {
}
}

void addCombinedState(result_type& result, const TemporaryStates& tmpStates,
const Surface& surface) const {
void updateMultiTrajectory(result_type& result,
const TemporaryStates& tmpStates,
const Surface& surface) const {
using PrtProjector =
MultiTrajectoryProjector<StatesType::ePredicted, traj_t>;
using FltProjector =
Expand Down Expand Up @@ -767,6 +778,7 @@ struct GsfActor {
m_cfg.abortOnError = options.abortOnError;
m_cfg.disableAllMaterialHandling = options.disableAllMaterialHandling;
m_cfg.weightCutoff = options.weightCutoff;
m_cfg.momentumCutoff = options.momentumCutoff;
m_cfg.mergeMethod = options.componentMergeMethod;
m_cfg.calibrationContext = &options.calibrationContext.get();
}
Expand Down
3 changes: 3 additions & 0 deletions Examples/Algorithms/TrackFitting/src/GsfFitterFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class TrackingGeometry;
} // namespace Acts

using namespace ActsExamples;
using namespace Acts::UnitLiterals;

namespace {

Expand Down Expand Up @@ -101,6 +102,7 @@ struct GsfFitterFunctionImpl final : public ActsExamples::TrackFitterFunction {
&Acts::GainMatrixUpdater::operator()<Acts::VectorMultiTrajectory>>(
&updater);

const double momentumCutoff = 500_MeV;
benjaminhuth marked this conversation as resolved.
Show resolved Hide resolved
Acts::GsfOptions<Acts::VectorMultiTrajectory> gsfOptions{
options.geoContext,
options.magFieldContext,
Expand All @@ -110,6 +112,7 @@ struct GsfFitterFunctionImpl final : public ActsExamples::TrackFitterFunction {
&(*options.referenceSurface),
maxComponents,
weightCutoff,
momentumCutoff,
abortOnError,
disableAllMaterialHandling};
gsfOptions.componentMergeMethod = mergeMethod;
Expand Down
3 changes: 2 additions & 1 deletion docs/core/reconstruction/track_fitting.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ outline:

The fit can be customized with several options. Important ones are:
* *maximum components*: How many components at maximum should be kept.
* *weight cut*: When to drop components.
* *weight cut*: When to drop components because of too little weight.
* *momentum cut*: When to drop components because of too low momentum. This can help stabilizing the fit, as low momenta tend to disturb the navigation.
* *component merging*: How a multi-component state is reduced to a single set of parameters and covariance. The method can be chosen with the enum {enum}`Acts::ComponentMergeMethod`. Two methods are supported currently:
* The *mean* computes the mean and the covariance of the mean.
* *max weight* takes the parameters of component with the maximum weight and computes the variance around these. This is a cheap approximation of the mode, which is not implemented currently.
Expand Down
Loading