Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add GSF component momentum cut to stabilize fit #2661

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Core/include/Acts/TrackFitting/GaussianSumFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,12 @@ struct GaussianSumFitter {
r.measurementStates++;
r.processedStates++;

const auto& params = *fwdGsfResult.lastMeasurementState;
assert(!fwdGsfResult.lastMeasurementComponents.empty());
assert(fwdGsfResult.lastMeasurementSurface != nullptr);
MultiComponentBoundTrackParameters params(
fwdGsfResult.lastMeasurementSurface->getSharedPtr(),
fwdGsfResult.lastMeasurementComponents,
sParameters.particleHypothesis());

return m_propagator.template propagate<std::decay_t<decltype(params)>,
decltype(bwdPropOptions),
Expand Down
2 changes: 2 additions & 0 deletions Core/include/Acts/TrackFitting/GsfOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ struct GsfOptions {

double weightCutoff = 1.e-4;

double momentumCutoff = 0;

bool abortOnError = false;

bool disableAllMaterialHandling = false;
Expand Down
165 changes: 87 additions & 78 deletions Core/include/Acts/TrackFitting/detail/GsfActor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ struct GsfResult {

/// The last multi-component measurement state. Used to initialize the
/// backward pass.
std::optional<MultiComponentBoundTrackParameters> lastMeasurementState;
std::vector<std::tuple<double, BoundVector, BoundMatrix>>
lastMeasurementComponents;

/// The last measurement surface. Used to initialize the backward pass.
const Acts::Surface* lastMeasurementSurface = nullptr;

/// Some counting
std::size_t measurementStates = 0;
Expand Down Expand Up @@ -96,6 +100,9 @@ struct GsfActor {
/// When to discard components
double weightCutoff = 1.0e-4;

/// When to discard components
double momentumCutoff = 500_MeV;

/// When this option is enabled, material information on all surfaces is
/// ignored. This disables the component convolution as well as the handling
/// of energy. This may be useful for debugging.
Expand Down Expand Up @@ -134,6 +141,8 @@ struct GsfActor {
std::map<MultiTrajectoryTraits::IndexType, double> weights;
};

using FiltProjector = MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;

/// @brief GSF actor operation
///
/// @tparam propagator_state_t is the type of Propagator state
Expand Down Expand Up @@ -262,7 +271,8 @@ struct GsfActor {
return;
}

updateStepper(state, stepper, tmpStates);
FiltProjector proj{tmpStates.traj, tmpStates.weights};
updateStepper(state, stepper, navigator, tmpStates.tips, proj);
}
// We have material, we thus need a component cache since we will
// convolute the components and later reduce them again before updating
Expand Down Expand Up @@ -292,11 +302,9 @@ struct GsfActor {
result);

if (componentCache.empty()) {
ACTS_WARNING(
"No components left after applying energy loss. "
"Is the weight cutoff "
<< m_cfg.weightCutoff << " too high?");
ACTS_WARNING("Return to propagator without applying energy loss");
ACTS_DEBUG(
"No components left after applying energy loss, stop propagation");
navigator.navigationBreak(state.navigation, true);
return;
}

Expand All @@ -307,7 +315,8 @@ struct GsfActor {

removeLowWeightComponents(componentCache);

updateStepper(state, stepper, navigator, componentCache);
auto proj = [](const auto& a) -> decltype(a) { return a; };
updateStepper(state, stepper, navigator, componentCache, proj);
}

// If we have only done preUpdate before, now do postUpdate
Expand Down Expand Up @@ -409,6 +418,12 @@ struct GsfActor {
assert(p_prev + delta_p > 0. && "new momentum must be > 0");
new_pars[eBoundQOverP] = old_bound.charge() / (p_prev + delta_p);

if (p_prev + delta_p < m_cfg.momentumCutoff) {
ACTS_VERBOSE("Skip new component with p=" << p_prev + delta_p
<< " GeV");
continue;
}

// compute inverse variance of p from mixture and update covariance
auto new_cov = old_bound.covariance().value();

Expand Down Expand Up @@ -455,50 +470,19 @@ struct GsfActor {
}
}

/// Function that updates the stepper from the MultiTrajectory
template <typename propagator_state_t, typename stepper_t>
void updateStepper(propagator_state_t& state, const stepper_t& stepper,
const TemporaryStates& tmpStates) const {
auto cmps = stepper.componentIterable(state.stepping);

for (auto [idx, cmp] : zip(tmpStates.tips, cmps)) {
// we set ignored components to missed, so we can remove them after
// the loop
if (tmpStates.weights.at(idx) < m_cfg.weightCutoff) {
cmp.status() = Intersection3D::Status::missed;
continue;
}

auto proxy = tmpStates.traj.getTrackState(idx);

cmp.pars() =
MultiTrajectoryHelpers::freeFiltered(state.options.geoContext, proxy);
cmp.cov() = proxy.filteredCovariance();
cmp.weight() = tmpStates.weights.at(idx);
}

stepper.removeMissedComponents(state.stepping);

// TODO we have two normalization passes here now, this can probably be
// optimized
detail::normalizeWeights(cmps,
[&](auto cmp) -> double& { return cmp.weight(); });
}

/// Function that updates the stepper from the ComponentCache
template <typename propagator_state_t, typename stepper_t,
typename navigator_t>
typename navigator_t, typename range_t, typename proj_t>
void updateStepper(propagator_state_t& state, const stepper_t& stepper,
const navigator_t& navigator,
const std::vector<ComponentCache>& componentCache) const {
const navigator_t& navigator, const range_t& range,
const proj_t& proj) const {
const auto& surface = *navigator.currentSurface(state.navigation);

// Clear components before adding new ones
stepper.clearComponents(state.stepping);

// Finally loop over components
for (const auto& [weight, pars, cov] : componentCache) {
// Add the component to the stepper
for (const auto& cmp : range) {
const auto& [weight, pars, cov] = proj(cmp);
BoundTrackParameters bound(surface.getSharedPtr(), pars, cov,
stepper.particleHypothesis(state.stepping));

Expand All @@ -509,12 +493,12 @@ struct GsfActor {
continue;
}

auto& cmp = *res;
cmp.jacToGlobal() = surface.boundToFreeJacobian(state.geoContext, pars);
cmp.pathAccumulated() = state.stepping.pathAccumulated;
cmp.jacobian() = Acts::BoundMatrix::Identity();
cmp.derivative() = Acts::FreeVector::Zero();
cmp.jacTransport() = Acts::FreeMatrix::Identity();
auto& proxy = *res;
proxy.jacToGlobal() = surface.boundToFreeJacobian(state.geoContext, pars);
proxy.pathAccumulated() = state.stepping.pathAccumulated;
proxy.jacobian() = Acts::BoundMatrix::Identity();
proxy.derivative() = Acts::FreeVector::Zero();
proxy.jacTransport() = Acts::FreeMatrix::Identity();
}
}

Expand All @@ -528,11 +512,15 @@ struct GsfActor {
const SourceLink& source_link) const {
const auto& surface = *navigator.currentSurface(state.navigation);

// This allows to easily project the state to weight, filtered pars,
// filtered cov
const FiltProjector proj{tmpStates.traj, tmpStates.weights};

// Boolean flag, to distinguish measurement and outlier states. This flag
// is only modified by the valid-measurement-branch, so only if there
// isn't any valid measurement state, the flag stays false and the state
// is thus counted as an outlier
bool is_valid_measurement = false;
bool isValidMeasurement = false;

auto cmps = stepper.componentIterable(state.stepping);
for (auto cmp : cmps) {
Expand All @@ -550,54 +538,73 @@ struct GsfActor {

const auto& trackStateProxy = *trackStateProxyRes;

const auto p = state.stepping.particleHypothesis.extractMomentum(
trackStateProxy.filtered()[eBoundQOverP]);
if (p < m_cfg.momentumCutoff) {
ACTS_VERBOSE("Discard component with momentum "
<< p << " GeV after Kalman update");
continue;
}

// If at least one component is no outlier, we consider the whole thing
// as a measurementState
if (trackStateProxy.typeFlags().test(
Acts::TrackStateFlag::MeasurementFlag)) {
is_valid_measurement = true;
isValidMeasurement = true;
}

tmpStates.tips.push_back(trackStateProxy.index());
tmpStates.weights[tmpStates.tips.back()] = cmp.weight();
}

computePosteriorWeights(tmpStates.traj, tmpStates.tips, tmpStates.weights);
// compute the posterior weights
if (!tmpStates.tips.empty()) {
computePosteriorWeights(tmpStates.traj, tmpStates.tips,
tmpStates.weights);

detail::normalizeWeights(tmpStates.tips, [&](auto idx) -> double& {
return tmpStates.weights.at(idx);
});
detail::normalizeWeights(tmpStates.tips, [&](auto idx) -> double& {
return tmpStates.weights.at(idx);
});
}

// Remove low weight components
auto newEnd = std::remove_if(
tmpStates.tips.begin(), tmpStates.tips.end(),
[&](auto t) { return tmpStates.weights[t] < m_cfg.weightCutoff; });

if (newEnd != tmpStates.tips.end()) {
tmpStates.tips.erase(newEnd, tmpStates.tips.end());
detail::normalizeWeights(tmpStates.tips, [&](auto idx) -> double& {
return tmpStates.weights.at(idx);
});
}

// Break navigation if we have no states left
if (tmpStates.tips.empty()) {
ACTS_DEBUG("no components left after Kalman update, break navigation!");
navigator.navigationBreak(state.navigation, true);
return Acts::Result<void>::success();
}

// Do the statistics
++result.processedStates;

// TODO should outlier states also be counted here?
if (is_valid_measurement) {
if (isValidMeasurement) {
++result.measurementStates;
}

addCombinedState(result, tmpStates, surface);
result.lastMeasurementTip = result.currentTip;

using FiltProjector =
MultiTrajectoryProjector<StatesType::eFiltered, traj_t>;
FiltProjector proj{tmpStates.traj, tmpStates.weights};

std::vector<std::tuple<double, BoundVector, BoundMatrix>> v;
updateMultiTrajectory(result, tmpStates, surface);

// TODO Check why can zero weights can occur
result.lastMeasurementTip = result.currentTip;
result.lastMeasurementSurface = &surface;
result.lastMeasurementComponents.clear();
for (const auto& idx : tmpStates.tips) {
const auto [w, p, c] = proj(idx);
if (w > 0.0) {
v.push_back({w, p, c});
}
assert(w > 0.0);
result.lastMeasurementComponents.push_back({w, p, c});
}

normalizeWeights(v, [](auto& c) -> double& { return std::get<double>(c); });

result.lastMeasurementState = MultiComponentBoundTrackParameters(
surface.getSharedPtr(), std::move(v),
stepper.particleHypothesis(state.stepping));

// Return success
return Acts::Result<void>::success();
}
Expand Down Expand Up @@ -648,7 +655,7 @@ struct GsfActor {

++result.processedStates;

addCombinedState(result, tmpStates, surface);
updateMultiTrajectory(result, tmpStates, surface);

return Result<void>::success();
}
Expand Down Expand Up @@ -694,8 +701,9 @@ struct GsfActor {
}
}

void addCombinedState(result_type& result, const TemporaryStates& tmpStates,
const Surface& surface) const {
void updateMultiTrajectory(result_type& result,
const TemporaryStates& tmpStates,
const Surface& surface) const {
using PrtProjector =
MultiTrajectoryProjector<StatesType::ePredicted, traj_t>;
using FltProjector =
Expand Down Expand Up @@ -772,6 +780,7 @@ struct GsfActor {
m_cfg.abortOnError = options.abortOnError;
m_cfg.disableAllMaterialHandling = options.disableAllMaterialHandling;
m_cfg.weightCutoff = options.weightCutoff;
m_cfg.momentumCutoff = options.momentumCutoff;
m_cfg.mergeMethod = options.componentMergeMethod;
m_cfg.calibrationContext = &options.calibrationContext.get();
}
Expand Down
3 changes: 1 addition & 2 deletions Core/include/Acts/TrackFitting/detail/GsfUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,8 @@ class ScopedGsfInfoPrinterAndChecker {
assert(allFinite && "weights not finite at the start");
assert(allNormalized && "not normalized at the start");
} else {
assert(!zeroComponents && "no cmps at the end");
assert((zeroComponents || allNormalized) && "not normalized at the end");
assert(allFinite && "weights not finite at the end");
assert(allNormalized && "not normalized at the end");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ std::shared_ptr<TrackFitterFunction> makeGsfFitterFunction(
std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents,
double weightCutoff, Acts::ComponentMergeMethod componentMergeMethod,
double weightCutoff, double momentumCutoff,
Acts::ComponentMergeMethod componentMergeMethod,
MixtureReductionAlgorithm mixtureReductionAlgorithm,
const Acts::Logger& logger);

Expand Down
7 changes: 6 additions & 1 deletion Examples/Algorithms/TrackFitting/src/GsfFitterFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class TrackingGeometry;
} // namespace Acts

using namespace ActsExamples;
using namespace Acts::UnitLiterals;

namespace {

Expand All @@ -78,6 +79,7 @@ struct GsfFitterFunctionImpl final : public ActsExamples::TrackFitterFunction {

std::size_t maxComponents = 0;
double weightCutoff = 0;
double momentumCutoff = 0;
bool abortOnError = false;
bool disableAllMaterialHandling = false;
MixtureReductionAlgorithm reductionAlg =
Expand Down Expand Up @@ -110,6 +112,7 @@ struct GsfFitterFunctionImpl final : public ActsExamples::TrackFitterFunction {
&(*options.referenceSurface),
maxComponents,
weightCutoff,
momentumCutoff,
abortOnError,
disableAllMaterialHandling};
gsfOptions.componentMergeMethod = mergeMethod;
Expand Down Expand Up @@ -177,7 +180,8 @@ std::shared_ptr<TrackFitterFunction> ActsExamples::makeGsfFitterFunction(
std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents,
double weightCutoff, Acts::ComponentMergeMethod componentMergeMethod,
double weightCutoff, double momentumCutoff,
Acts::ComponentMergeMethod componentMergeMethod,
MixtureReductionAlgorithm mixtureReductionAlgorithm,
const Acts::Logger& logger) {
// Standard fitter
Expand Down Expand Up @@ -211,6 +215,7 @@ std::shared_ptr<TrackFitterFunction> ActsExamples::makeGsfFitterFunction(
std::move(trackFitter), std::move(directTrackFitter), geo);
fitterFunction->maxComponents = maxComponents;
fitterFunction->weightCutoff = weightCutoff;
fitterFunction->momentumCutoff = momentumCutoff;
fitterFunction->mergeMethod = componentMergeMethod;
fitterFunction->reductionAlg = mixtureReductionAlgorithm;

Expand Down
1 change: 1 addition & 0 deletions Examples/Python/python/acts/examples/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,7 @@ def addTruthTrackingGsf(
"componentMergeMethod": acts.examples.ComponentMergeMethod.maxWeight,
"mixtureReductionAlgorithm": acts.examples.MixtureReductionAlgorithm.KLDistance,
"weightCutoff": 1.0e-4,
"momentumCutoff": 100 * u.MeV,
"level": customLogLevel(),
}

Expand Down
Loading
Loading