From 3f3142bcca58f9a7cc6086891a3f153b1836cfa2 Mon Sep 17 00:00:00 2001 From: Benjamin Huth Date: Mon, 8 Jul 2024 09:18:15 +0200 Subject: [PATCH 01/10] checkout relevant changes --- .../TrackFindingExaTrkX/CMakeLists.txt | 1 + .../TrackFindingAlgorithmExaTrkX.hpp | 24 +-- .../TrackFindingExaTrkX/TruthGraphBuilder.hpp | 81 +++++++++ .../src/TrackFindingAlgorithmExaTrkX.cpp | 130 +++------------ .../src/TruthGraphBuilder.cpp | 157 ++++++++++++++++++ Examples/Python/src/ExaTrkXTrackFinding.cpp | 12 +- 6 files changed, 276 insertions(+), 129 deletions(-) create mode 100644 Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp create mode 100644 Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp diff --git a/Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt b/Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt index b4a0f280555..5b7ea5e820e 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt +++ b/Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt @@ -3,6 +3,7 @@ add_library( src/TrackFindingAlgorithmExaTrkX.cpp src/PrototracksToParameters.cpp src/TrackFindingFromPrototrackAlgorithm.cpp + src/TruthGraphBuilder.cpp ) target_include_directories( diff --git a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp index c6eff23d627..a682559e056 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp @@ -19,6 +19,7 @@ #include "ActsExamples/EventData/SimSpacePoint.hpp" #include "ActsExamples/Framework/DataHandle.hpp" #include "ActsExamples/Framework/IAlgorithm.hpp" +#include "ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp" #include #include @@ -42,15 +43,9 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { /// * cluster size in local x /// * cluster size in local y std::string inputClusters; - - /// Input simhits (Optional). - std::string inputSimHits; - /// Input measurement simhit map (Optional). - std::string inputParticles; - /// Input measurement simhit map (Optional). - std::string inputMeasurementSimhitsMap; - - /// Output protoTracks collection. + /// Input truth graph (Optional). + std::string inputTruthGraph; + /// Output prototracks std::string outputProtoTracks; /// Output graph (optional) @@ -73,10 +68,6 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { /// Remove track candidates with 2 or less hits bool filterShortTracks = false; - - /// Target graph properties - std::size_t targetMinHits = 3; - double targetMinPT = 500 * Acts::UnitConstants::MeV; }; /// Constructor of the track finding algorithm @@ -119,16 +110,11 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { "InputSpacePoints"}; ReadDataHandle m_inputClusters{this, "InputClusters"}; + ReadDataHandle m_inputTruthGraph{this, "InputTruthGraph"}; WriteDataHandle m_outputProtoTracks{this, "OutputProtoTracks"}; WriteDataHandle m_outputGraph{ this, "OutputGraph"}; - - // for truth graph - ReadDataHandle m_inputSimHits{this, "InputSimHits"}; - ReadDataHandle m_inputParticles{this, "InputParticles"}; - ReadDataHandle> m_inputMeasurementMap{ - this, "InputMeasurementMap"}; }; } // namespace ActsExamples diff --git a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp new file mode 100644 index 00000000000..6408be94d81 --- /dev/null +++ b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp @@ -0,0 +1,81 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2022 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Definitions/Units.hpp" +#include "ActsExamples/EventData/Cluster.hpp" +#include "ActsExamples/EventData/ProtoTrack.hpp" +#include "ActsExamples/EventData/SimHit.hpp" +#include "ActsExamples/EventData/SimParticle.hpp" +#include "ActsExamples/EventData/SimSpacePoint.hpp" +#include "ActsExamples/Framework/DataHandle.hpp" +#include "ActsExamples/Framework/IAlgorithm.hpp" + +namespace ActsExamples { + +using TruthGraph = std::vector; + +/// Algorithm to create a truth graph for computing truth metrics +/// Requires spacepoints and particles collection +/// Either provide a measurements particles map, or a measurement simhit map + +/// simhits +class TruthGraphBuilder final : public IAlgorithm { + public: + struct Config { + /// Input spacepoint collection + std::string inputSpacePoints; + /// Input particles collection + std::string inputParticles; + /// Input measurement particles map (Optional). + std::string inputMeasurementParticlesMap; + /// Input simhits (Optional). + std::string inputSimHits; + /// Input measurement simhit map (Optional). + std::string inputMeasurementSimHitsMap; + /// Output truth graph + std::string outputGraph; + + double targetMinPT = 0.5; + std::size_t targetMinSize = 3; + }; + + TruthGraphBuilder(Config cfg, Acts::Logging::Level lvl); + + ~TruthGraphBuilder() override = default; + + ActsExamples::ProcessCode execute( + const ActsExamples::AlgorithmContext& ctx) const final; + + const Config& config() const { return m_cfg; } + + private: + Config m_cfg; + + TruthGraph buildFromMeasurements( + const SimSpacePointContainer& spacepoints, + const SimParticleContainer& particles, + const IndexMultimap& measPartMap) const; + + TruthGraph buildFromSimhits(const SimSpacePointContainer& spacepoints, + const IndexMultimap& measHitMap, + const SimHitContainer& simhits, + const SimParticleContainer& particles) const; + + ReadDataHandle m_inputSpacePoints{this, + "InputSpacePoints"}; + ReadDataHandle m_inputParticles{this, "InputParticles"}; + ReadDataHandle> m_inputMeasParticlesMap{ + this, "InputMeasParticlesMap"}; + ReadDataHandle m_inputSimhits{this, "InputSimhits"}; + ReadDataHandle> m_inputMeasSimhitMap{ + this, "InputMeasSimhitMap"}; + + WriteDataHandle m_outputGraph{this, "OutputGraph"}; +}; +} // namespace ActsExamples diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp index 9e5c80cb881..bbc00682c6f 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp @@ -25,97 +25,16 @@ using namespace Acts::UnitLiterals; namespace { -class ExamplesEdmHook : public Acts::ExaTrkXHook { - double m_targetPT = 0.5_GeV; - std::size_t m_targetSize = 3; - - std::unique_ptr m_logger; - std::unique_ptr m_truthGraphHook; - std::unique_ptr m_targetGraphHook; - std::unique_ptr m_graphStoreHook; - - const Acts::Logger& logger() const { return *m_logger; } - - struct HitInfo { - std::size_t spacePointIndex; - std::int32_t hitIndex; - }; - - public: - ExamplesEdmHook(const SimSpacePointContainer& spacepoints, - const IndexMultimap& measHitMap, - const SimHitContainer& simhits, - const SimParticleContainer& particles, - std::size_t targetMinHits, double targetMinPT, - const Acts::Logger& logger) - : m_targetPT(targetMinPT), - m_targetSize(targetMinHits), - m_logger(logger.clone("MetricsHook")) { - // Associate tracks to graph, collect momentum - std::unordered_map> tracks; - - for (auto i = 0ul; i < spacepoints.size(); ++i) { - const auto measId = spacepoints[i] - .sourceLinks()[0] - .template get() - .index(); - - auto [a, b] = measHitMap.equal_range(measId); - for (auto it = a; it != b; ++it) { - const auto& hit = *simhits.nth(it->second); - - tracks[hit.particleId()].push_back({i, hit.index()}); - } - } - - // Collect edges for truth graph and target graph - std::vector truthGraph; - std::vector targetGraph; - - for (auto& [pid, track] : tracks) { - // Sort by hit index, so the edges are connected correctly - std::sort(track.begin(), track.end(), [](const auto& a, const auto& b) { - return a.hitIndex < b.hitIndex; - }); - - auto found = particles.find(pid); - if (found == particles.end()) { - ACTS_WARNING("Did not find " << pid << ", skip track"); - continue; - } - - for (auto i = 0ul; i < track.size() - 1; ++i) { - truthGraph.push_back(track[i].spacePointIndex); - truthGraph.push_back(track[i + 1].spacePointIndex); - - if (found->transverseMomentum() > m_targetPT && - track.size() >= m_targetSize) { - targetGraph.push_back(track[i].spacePointIndex); - targetGraph.push_back(track[i + 1].spacePointIndex); - } - } - } +struct LoopHook : public Acts::ExaTrkXHook { + std::vector hooks; - m_truthGraphHook = std::make_unique( - truthGraph, logger.clone()); - m_targetGraphHook = std::make_unique( - targetGraph, logger.clone()); - m_graphStoreHook = std::make_unique(); - } - - ~ExamplesEdmHook() {} - - auto storedGraph() const { return m_graphStoreHook->storedGraph(); } + ~LoopHook() {} void operator()(const std::any& nodes, const std::any& edges, const std::any& weights) const override { - ACTS_INFO("Metrics for total graph:"); - (*m_truthGraphHook)(nodes, edges, weights); - ACTS_INFO("Metrics for target graph (pT > " - << m_targetPT / Acts::UnitConstants::GeV - << " GeV, nHits >= " << m_targetSize << "):"); - (*m_targetGraphHook)(nodes, edges, weights); - (*m_graphStoreHook)(nodes, edges, weights); + for (auto hook : hooks) { + (*hook)(nodes, edges, weights); + } } }; @@ -156,10 +75,7 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX( m_inputClusters.maybeInitialize(m_cfg.inputClusters); m_outputProtoTracks.initialize(m_cfg.outputProtoTracks); - m_inputSimHits.maybeInitialize(m_cfg.inputSimHits); - m_inputParticles.maybeInitialize(m_cfg.inputParticles); - m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap); - + m_inputTruthGraph.maybeInitialize(m_cfg.inputTruthGraph); m_outputGraph.maybeInitialize(m_cfg.outputGraph); // reserve space for timing @@ -181,17 +97,25 @@ enum feat : std::size_t { ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( const ActsExamples::AlgorithmContext& ctx) const { - // Read input data - auto spacepoints = m_inputSpacePoints(ctx); + // Setup hooks + LoopHook hook; + + std::unique_ptr truthGraphHook; + if (m_inputTruthGraph.isInitialized()) { + truthGraphHook = std::make_unique( + m_inputTruthGraph(ctx), this->logger().clone()); + hook.hooks.push_back(&*truthGraphHook); + } - auto hook = std::make_unique(); - if (m_inputSimHits.isInitialized() && m_inputMeasurementMap.isInitialized()) { - hook = std::make_unique( - spacepoints, m_inputMeasurementMap(ctx), m_inputSimHits(ctx), - m_inputParticles(ctx), m_cfg.targetMinHits, m_cfg.targetMinPT, - logger()); + std::unique_ptr graphStoreHook; + if (m_outputGraph.isInitialized()) { + graphStoreHook = std::make_unique(); + hook.hooks.push_back(&*graphStoreHook); } + // Read input data + auto spacepoints = m_inputSpacePoints(ctx); + std::optional clusters; if (m_inputClusters.isInitialized()) { clusters = m_inputClusters(ctx); @@ -253,7 +177,8 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( std::lock_guard lock(m_mutex); Acts::ExaTrkXTiming timing; - auto res = m_pipeline.run(features, spacepointIDs, *hook, &timing); + auto res = + m_pipeline.run(features, moduleIds, spacepointIDs, hook, &timing); m_timing.graphBuildingTime(timing.graphBuildingTime.count()); @@ -294,9 +219,8 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( ACTS_INFO("Created " << protoTracks.size() << " proto tracks"); m_outputProtoTracks(ctx, std::move(protoTracks)); - if (auto dhook = dynamic_cast(&*hook); - dhook && m_outputGraph.isInitialized()) { - auto graph = dhook->storedGraph(); + if (m_outputGraph.isInitialized()) { + auto graph = graphStoreHook->storedGraph(); std::transform( graph.first.begin(), graph.first.end(), graph.first.begin(), [&](const auto& a) -> std::int64_t { return spacepointIDs.at(a); }); diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp new file mode 100644 index 00000000000..b8eaf0dc668 --- /dev/null +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp @@ -0,0 +1,157 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2022 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include +#include + +using namespace Acts::UnitLiterals; + +namespace ActsExamples { + +ActsExamples::TruthGraphBuilder::TruthGraphBuilder(Config config, + Acts::Logging::Level level) + : ActsExamples::IAlgorithm("TruthGraphBuilder", level), + m_cfg(std::move(config)) { + m_inputSpacePoints.initialize(m_cfg.inputSpacePoints); + m_inputParticles.initialize(m_cfg.inputParticles); + m_outputGraph.initialize(m_cfg.outputGraph); + + m_inputMeasParticlesMap.maybeInitialize(m_cfg.inputMeasurementParticlesMap); + m_inputSimhits.maybeInitialize(m_cfg.inputSimHits); + m_inputMeasSimhitMap.maybeInitialize(m_cfg.inputMeasurementSimHitsMap); + + bool a = m_inputMeasParticlesMap.isInitialized(); + bool b = + m_inputSimhits.isInitialized() && m_inputMeasSimhitMap.isInitialized(); + + if (!(a || b)) { + throw std::invalid_argument("Missing inputs, cannot build truth graph"); + } +} + +TruthGraph TruthGraphBuilder::buildFromMeasurements( + const SimSpacePointContainer& spacepoints, + const SimParticleContainer& particles, + const IndexMultimap& measPartMap) const { + if (m_cfg.targetMinPT < 500_MeV) { + ACTS_WARNING( + "truth graph building based on distance from origin, this breaks down " + "for low pT particles. Consider using a higher target pT value"); + } + + // Associate tracks to graph, collect momentum + std::unordered_map> tracks; + + for (auto i = 0ul; i < spacepoints.size(); ++i) { + const auto measId = + spacepoints[i].sourceLinks()[0].template get().index(); + + auto [a, b] = measPartMap.equal_range(measId); + for (auto it = a; it != b; ++it) { + tracks[it->second].push_back(i); + } + } + + // Collect edges for truth graph and target graph + std::vector graph; + + for (auto& [pid, track] : tracks) { + auto found = particles.find(pid); + if (found == particles.end()) { + ACTS_WARNING("Did not find " << pid << ", skip track"); + continue; + } + + if (found->transverseMomentum() < m_cfg.targetMinPT || + track.size() < m_cfg.targetMinSize) { + continue; + } + + // Sort by radius (this breaks down if the particle has to low momentum) + std::sort(track.begin(), track.end(), [&](const auto& a, const auto& b) { + return spacepoints[a].r() < spacepoints[b].r(); + }); + + for (auto i = 0ul; i < track.size() - 1; ++i) { + graph.push_back(track[i]); + graph.push_back(track[i + 1]); + } + } + + return graph; +} + +struct HitInfo { + std::size_t spacePointIndex; + std::int32_t hitIndex; +}; + +TruthGraph TruthGraphBuilder::buildFromSimhits( + const SimSpacePointContainer& spacepoints, + const IndexMultimap& measHitMap, const SimHitContainer& simhits, + const SimParticleContainer& particles) const { + // Associate tracks to graph, collect momentum + std::unordered_map> tracks; + + for (auto i = 0ul; i < spacepoints.size(); ++i) { + const auto measId = + spacepoints[i].sourceLinks()[0].template get().index(); + + auto [a, b] = measHitMap.equal_range(measId); + for (auto it = a; it != b; ++it) { + const auto& hit = *simhits.nth(it->second); + + tracks[hit.particleId()].push_back({i, hit.index()}); + } + } + + // Collect edges for truth graph and target graph + std::vector truthGraph; + + for (auto& [pid, track] : tracks) { + // Sort by hit index, so the edges are connected correctly + std::sort(track.begin(), track.end(), [](const auto& a, const auto& b) { + return a.hitIndex < b.hitIndex; + }); + + auto found = particles.find(pid); + if (found == particles.end()) { + ACTS_WARNING("Did not find " << pid << ", skip track"); + continue; + } + + for (auto i = 0ul; i < track.size() - 1; ++i) { + if (found->transverseMomentum() > m_cfg.targetMinPT && + track.size() >= m_cfg.targetMinSize) { + truthGraph.push_back(track[i].spacePointIndex); + truthGraph.push_back(track[i + 1].spacePointIndex); + } + } + } + + return truthGraph; +} + +ActsExamples::ProcessCode ActsExamples::TruthGraphBuilder::execute( + const ActsExamples::AlgorithmContext& ctx) const { + // Read input data + const auto& spacepoints = m_inputSpacePoints(ctx); + const auto& particles = m_inputParticles(ctx); + + auto graph = (m_inputMeasParticlesMap.isInitialized()) + ? buildFromMeasurements(spacepoints, particles, + m_inputMeasParticlesMap(ctx)) + : buildFromSimhits(spacepoints, m_inputMeasSimhitMap(ctx), + m_inputSimhits(ctx), particles); + + m_outputGraph(ctx, std::move(graph)); + + return ProcessCode::SUCCESS; +} + +} // namespace ActsExamples diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index 1086086d006..d7fd8209d99 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -18,6 +18,7 @@ #include "ActsExamples/TrackFindingExaTrkX/PrototracksToParameters.hpp" #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp" #include "ActsExamples/TrackFindingExaTrkX/TrackFindingFromPrototrackAlgorithm.hpp" +#include "ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp" #include @@ -168,13 +169,10 @@ void addExaTrkXTrackFinding(Context &ctx) { #endif ACTS_PYTHON_DECLARE_ALGORITHM( - ActsExamples::TrackFindingAlgorithmExaTrkX, mex, - "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits, - inputParticles, inputClusters, inputMeasurementSimhitsMap, - outputProtoTracks, outputGraph, graphConstructor, edgeClassifiers, - trackBuilder, rScale, phiScale, zScale, cellCountScale, cellSumScale, - clusterXScale, clusterYScale, filterShortTracks, targetMinHits, - targetMinPT); + ActsExamples::TruthGraphBuilder, mex, "TruthGraphBuilder", + inputSpacePoints, inputSimHits, inputParticles, + inputMeasurementSimHitsMap, inputMeasurementParticlesMap, outputGraph, + targetMinPT, targetMinSize); { auto cls = From 0c24d90386924fc361356096db066784eb5f5b0f Mon Sep 17 00:00:00 2001 From: Benjamin Huth Date: Mon, 8 Jul 2024 09:37:04 +0200 Subject: [PATCH 02/10] fix --- .../TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp index bbc00682c6f..e74d0b10e37 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp @@ -178,7 +178,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( Acts::ExaTrkXTiming timing; auto res = - m_pipeline.run(features, moduleIds, spacepointIDs, hook, &timing); + m_pipeline.run(features, spacepointIDs, hook, &timing); m_timing.graphBuildingTime(timing.graphBuildingTime.count()); From 8dec410aa68d509e0396e9429ac079140a8d3125 Mon Sep 17 00:00:00 2001 From: Benjamin Huth Date: Mon, 8 Jul 2024 09:43:55 +0200 Subject: [PATCH 03/10] clang-format --- .../TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp index e74d0b10e37..f40bb26c74d 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp @@ -177,8 +177,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( std::lock_guard lock(m_mutex); Acts::ExaTrkXTiming timing; - auto res = - m_pipeline.run(features, spacepointIDs, hook, &timing); + auto res = m_pipeline.run(features, spacepointIDs, hook, &timing); m_timing.graphBuildingTime(timing.graphBuildingTime.count()); From 1e7ea3b4ae43f3dc687df78601ad6c20d1baed90 Mon Sep 17 00:00:00 2001 From: Benjamin Huth Date: Tue, 9 Jul 2024 13:20:03 +0200 Subject: [PATCH 04/10] pull changes in --- .../TrackFindingAlgorithmExaTrkX.hpp | 6 +- .../TrackFindingExaTrkX/TruthGraphBuilder.hpp | 18 +++-- .../src/TrackFindingAlgorithmExaTrkX.cpp | 4 +- .../src/TruthGraphBuilder.cpp | 50 +++++++++++-- .../include/ActsExamples/EventData/Graph.hpp | 32 ++++++++ Examples/Io/Csv/CMakeLists.txt | 1 + .../Io/Csv/CsvExaTrkXGraphReader.hpp | 64 ++++++++++++++++ .../Io/Csv/CsvExaTrkXGraphWriter.hpp | 8 +- Examples/Io/Csv/src/CsvExaTrkXGraphReader.cpp | 73 +++++++++++++++++++ Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp | 29 ++++---- Examples/Io/Csv/src/CsvOutputData.hpp | 7 ++ Examples/Python/src/ExaTrkXTrackFinding.cpp | 2 +- Examples/Python/src/Input.cpp | 5 ++ 13 files changed, 260 insertions(+), 39 deletions(-) create mode 100644 Examples/Framework/include/ActsExamples/EventData/Graph.hpp create mode 100644 Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp create mode 100644 Examples/Io/Csv/src/CsvExaTrkXGraphReader.cpp diff --git a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp index a682559e056..85cfb19729d 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp @@ -13,6 +13,7 @@ #include "Acts/Plugins/ExaTrkX/Stages.hpp" #include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" #include "ActsExamples/EventData/Cluster.hpp" +#include "ActsExamples/EventData/Graph.hpp" #include "ActsExamples/EventData/ProtoTrack.hpp" #include "ActsExamples/EventData/SimHit.hpp" #include "ActsExamples/EventData/SimParticle.hpp" @@ -110,11 +111,10 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { "InputSpacePoints"}; ReadDataHandle m_inputClusters{this, "InputClusters"}; - ReadDataHandle m_inputTruthGraph{this, "InputTruthGraph"}; + ReadDataHandle m_inputTruthGraph{this, "InputTruthGraph"}; WriteDataHandle m_outputProtoTracks{this, "OutputProtoTracks"}; - WriteDataHandle m_outputGraph{ - this, "OutputGraph"}; + WriteDataHandle m_outputGraph{this, "OutputGraph"}; }; } // namespace ActsExamples diff --git a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp index 6408be94d81..97c41acc2bb 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp @@ -10,6 +10,7 @@ #include "Acts/Definitions/Units.hpp" #include "ActsExamples/EventData/Cluster.hpp" +#include "ActsExamples/EventData/Graph.hpp" #include "ActsExamples/EventData/ProtoTrack.hpp" #include "ActsExamples/EventData/SimHit.hpp" #include "ActsExamples/EventData/SimParticle.hpp" @@ -19,8 +20,6 @@ namespace ActsExamples { -using TruthGraph = std::vector; - /// Algorithm to create a truth graph for computing truth metrics /// Requires spacepoints and particles collection /// Either provide a measurements particles map, or a measurement simhit map + @@ -43,6 +42,9 @@ class TruthGraphBuilder final : public IAlgorithm { double targetMinPT = 0.5; std::size_t targetMinSize = 3; + + /// Only allow one hit per track & module + bool uniqueModules = false; }; TruthGraphBuilder(Config cfg, Acts::Logging::Level lvl); @@ -57,15 +59,15 @@ class TruthGraphBuilder final : public IAlgorithm { private: Config m_cfg; - TruthGraph buildFromMeasurements( + std::vector buildFromMeasurements( const SimSpacePointContainer& spacepoints, const SimParticleContainer& particles, const IndexMultimap& measPartMap) const; - TruthGraph buildFromSimhits(const SimSpacePointContainer& spacepoints, - const IndexMultimap& measHitMap, - const SimHitContainer& simhits, - const SimParticleContainer& particles) const; + std::vector buildFromSimhits( + const SimSpacePointContainer& spacepoints, + const IndexMultimap& measHitMap, const SimHitContainer& simhits, + const SimParticleContainer& particles) const; ReadDataHandle m_inputSpacePoints{this, "InputSpacePoints"}; @@ -76,6 +78,6 @@ class TruthGraphBuilder final : public IAlgorithm { ReadDataHandle> m_inputMeasSimhitMap{ this, "InputMeasSimhitMap"}; - WriteDataHandle m_outputGraph{this, "OutputGraph"}; + WriteDataHandle m_outputGraph{this, "OutputGraph"}; }; } // namespace ActsExamples diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp index f40bb26c74d..bb7372180fe 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp @@ -103,7 +103,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( std::unique_ptr truthGraphHook; if (m_inputTruthGraph.isInitialized()) { truthGraphHook = std::make_unique( - m_inputTruthGraph(ctx), this->logger().clone()); + m_inputTruthGraph(ctx).edges, this->logger().clone()); hook.hooks.push_back(&*truthGraphHook); } @@ -223,7 +223,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( std::transform( graph.first.begin(), graph.first.end(), graph.first.begin(), [&](const auto& a) -> std::int64_t { return spacepointIDs.at(a); }); - m_outputGraph(ctx, std::move(graph)); + m_outputGraph(ctx, {graph.first, graph.second}); } return ActsExamples::ProcessCode::SUCCESS; diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp index b8eaf0dc668..a97a5604986 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp @@ -34,7 +34,7 @@ ActsExamples::TruthGraphBuilder::TruthGraphBuilder(Config config, } } -TruthGraph TruthGraphBuilder::buildFromMeasurements( +std::vector TruthGraphBuilder::buildFromMeasurements( const SimSpacePointContainer& spacepoints, const SimParticleContainer& particles, const IndexMultimap& measPartMap) const { @@ -59,11 +59,14 @@ TruthGraph TruthGraphBuilder::buildFromMeasurements( // Collect edges for truth graph and target graph std::vector graph; + std::size_t notFoundParticles = 0; + std::size_t moduleDuplicatesRemoved = 0; for (auto& [pid, track] : tracks) { auto found = particles.find(pid); if (found == particles.end()) { - ACTS_WARNING("Did not find " << pid << ", skip track"); + ACTS_VERBOSE("Did not find " << pid << ", skip track"); + notFoundParticles++; continue; } @@ -72,17 +75,47 @@ TruthGraph TruthGraphBuilder::buildFromMeasurements( continue; } + const Acts::Vector3 vtx = found->fourPosition().segment<3>(0); + auto radiusForOrdering = [&](std::size_t i) { + const auto& sp = spacepoints[i]; + return std::hypot(sp.x() - vtx[0], sp.y() - vtx[1], sp.z() - vtx[2]); + }; + // Sort by radius (this breaks down if the particle has to low momentum) std::sort(track.begin(), track.end(), [&](const auto& a, const auto& b) { - return spacepoints[a].r() < spacepoints[b].r(); + return radiusForOrdering(a) < radiusForOrdering(b); }); + if (m_cfg.uniqueModules) { + auto newEnd = std::unique( + track.begin(), track.end(), [&](const auto& a, const auto& b) { + auto gidA = spacepoints[a] + .sourceLinks()[0] + .template get() + .geometryId(); + auto gidB = spacepoints[b] + .sourceLinks()[0] + .template get() + .geometryId(); + return gidA == gidB; + }); + moduleDuplicatesRemoved += std::distance(newEnd, track.end()); + track.erase(newEnd, track.end()); + } + for (auto i = 0ul; i < track.size() - 1; ++i) { graph.push_back(track[i]); graph.push_back(track[i + 1]); } } + ACTS_DEBUG("Did not find particles for " << notFoundParticles << " tracks"); + if (moduleDuplicatesRemoved > 0) { + ACTS_DEBUG( + "Removed " << moduleDuplicatesRemoved + << " hit to ensure a unique hit per track and module"); + } + return graph; } @@ -91,7 +124,7 @@ struct HitInfo { std::int32_t hitIndex; }; -TruthGraph TruthGraphBuilder::buildFromSimhits( +std::vector TruthGraphBuilder::buildFromSimhits( const SimSpacePointContainer& spacepoints, const IndexMultimap& measHitMap, const SimHitContainer& simhits, const SimParticleContainer& particles) const { @@ -143,13 +176,18 @@ ActsExamples::ProcessCode ActsExamples::TruthGraphBuilder::execute( const auto& spacepoints = m_inputSpacePoints(ctx); const auto& particles = m_inputParticles(ctx); - auto graph = (m_inputMeasParticlesMap.isInitialized()) + auto edges = (m_inputMeasParticlesMap.isInitialized()) ? buildFromMeasurements(spacepoints, particles, m_inputMeasParticlesMap(ctx)) : buildFromSimhits(spacepoints, m_inputMeasSimhitMap(ctx), m_inputSimhits(ctx), particles); - m_outputGraph(ctx, std::move(graph)); + ACTS_DEBUG("Truth track edges: " << edges.size() / 2); + + Graph g; + g.edges = std::move(edges); + + m_outputGraph(ctx, std::move(g)); return ProcessCode::SUCCESS; } diff --git a/Examples/Framework/include/ActsExamples/EventData/Graph.hpp b/Examples/Framework/include/ActsExamples/EventData/Graph.hpp new file mode 100644 index 00000000000..5508958ccaa --- /dev/null +++ b/Examples/Framework/include/ActsExamples/EventData/Graph.hpp @@ -0,0 +1,32 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2024 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include +#include + +namespace ActsExamples { + +/// Lightweight graph representation for GNN debugging +/// +struct Graph { + /// The edges-vector contains flattened edge-pairs.Usually, the indices + /// refer to a spacepoint collection. + /// + /// Structure: [ source0, dest0, source1, dest1, ..., sourceN, destN ] + std::vector edges; + + /// The weight-vector should have half the size of the edges-vector (or + /// be empty if missing). + /// + /// Structure: [ weight0, weight1, ..., weightN ] + std::vector weights; +}; + +} // namespace ActsExamples diff --git a/Examples/Io/Csv/CMakeLists.txt b/Examples/Io/Csv/CMakeLists.txt index a2479cc9206..7d367d68892 100644 --- a/Examples/Io/Csv/CMakeLists.txt +++ b/Examples/Io/Csv/CMakeLists.txt @@ -19,6 +19,7 @@ add_library( src/CsvProtoTrackWriter.cpp src/CsvSpacePointWriter.cpp src/CsvExaTrkXGraphWriter.cpp + src/CsvExaTrkXGraphReader.cpp src/CsvBFieldWriter.cpp) target_include_directories( ActsExamplesIoCsv diff --git a/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp new file mode 100644 index 00000000000..4b41d57a448 --- /dev/null +++ b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp @@ -0,0 +1,64 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2022 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/Graph.hpp" +#include "ActsExamples/Framework/DataHandle.hpp" +#include "ActsExamples/Framework/IReader.hpp" +#include "ActsExamples/Framework/ProcessCode.hpp" + +#include +#include +#include +#include +#include + +namespace ActsExamples { +struct AlgorithmContext; + +class CsvExaTrkXGraphReader final : public IReader { + public: + struct Config { + /// Where to read input files from. + std::string inputDir; + /// Input filename stem. + std::string inputStem; + /// Which vs collection to read into. + std::string outputGraph; + }; + + /// Construct the particle reader. + /// + /// @param config is the configuration object + /// @param level is the logging level + CsvExaTrkXGraphReader(const Config& config, Acts::Logging::Level level); + + /// Return the available events range. + std::pair availableEvents() const override; + + /// Read out data from the input stream. + ProcessCode read(const ActsExamples::AlgorithmContext& ctx) override; + + /// Return the name of the component + std::string name() const override { return "CsvExaTrkXGraphReader"; } + + /// Readonly access to the config + const Config& config() const { return m_cfg; } + + private: + Config m_cfg; + std::pair m_eventsRange; + std::unique_ptr m_logger; + + WriteDataHandle m_outputGraph{this, "OutputGraph"}; + const Acts::Logger& logger() const { return *m_logger; } +}; + +} // namespace ActsExamples diff --git a/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp index eaf88c97a1d..15a2f7bb182 100644 --- a/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp +++ b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp @@ -9,6 +9,7 @@ #pragma once #include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/Graph.hpp" #include "ActsExamples/Framework/ProcessCode.hpp" #include "ActsExamples/Framework/WriterT.hpp" #include "ActsExamples/Utilities/Paths.hpp" @@ -20,8 +21,7 @@ namespace ActsExamples { struct AlgorithmContext; -class CsvExaTrkXGraphWriter final - : public WriterT, std::vector>> { +class CsvExaTrkXGraphWriter final : public WriterT { public: struct Config { /// Which simulated (truth) hits collection to use. @@ -46,9 +46,7 @@ class CsvExaTrkXGraphWriter final /// /// @param[in] ctx is the algorithm context /// @param[in] simHits are the simhits to be written - ProcessCode writeT(const AlgorithmContext& ctx, - const std::pair, - std::vector>& graph) override; + ProcessCode writeT(const AlgorithmContext& ctx, const Graph& graph) override; private: Config m_cfg; diff --git a/Examples/Io/Csv/src/CsvExaTrkXGraphReader.cpp b/Examples/Io/Csv/src/CsvExaTrkXGraphReader.cpp new file mode 100644 index 00000000000..81eaad7eb51 --- /dev/null +++ b/Examples/Io/Csv/src/CsvExaTrkXGraphReader.cpp @@ -0,0 +1,73 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2017 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp" + +#include "Acts/Definitions/PdgParticle.hpp" +#include "Acts/Definitions/Units.hpp" +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/SimParticle.hpp" +#include "ActsExamples/Framework/AlgorithmContext.hpp" +#include "ActsExamples/Utilities/Paths.hpp" +#include "ActsFatras/EventData/Barcode.hpp" +#include "ActsFatras/EventData/Particle.hpp" +#include "ActsFatras/EventData/ProcessType.hpp" +#include + +#include +#include +#include +#include + +#include + +#include "CsvOutputData.hpp" + +namespace ActsExamples { + +CsvExaTrkXGraphReader::CsvExaTrkXGraphReader(const Config& config, + Acts::Logging::Level level) + : m_cfg(config), + m_eventsRange( + determineEventFilesRange(m_cfg.inputDir, m_cfg.inputStem + ".csv")), + m_logger(Acts::getDefaultLogger("CsvExaTrkXGraphReader", level)) { + if (m_cfg.inputStem.empty()) { + throw std::invalid_argument("Missing input filename stem"); + } + + m_outputGraph.initialize(m_cfg.outputGraph); +} + +std::pair CsvExaTrkXGraphReader::availableEvents() + const { + return m_eventsRange; +} + +ProcessCode CsvExaTrkXGraphReader::read(const AlgorithmContext& ctx) { + SimParticleContainer::sequence_type unordered; + + auto path = perEventFilepath(m_cfg.inputDir, m_cfg.inputStem + ".csv", + ctx.eventNumber); + // vt and m are an optional columns + dfe::NamedTupleCsvReader reader(path, {"vt", "m"}); + GraphData data; + + Graph g; + + while (reader.read(data)) { + g.edges.push_back(data.edge0); + g.edges.push_back(data.edge1); + g.weights.push_back(data.weight); + } + + m_outputGraph(ctx, std::move(g)); + + return ProcessCode::SUCCESS; +} + +} // namespace ActsExamples diff --git a/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp b/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp index fd7e0da98c6..49c394b7ed2 100644 --- a/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp +++ b/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp @@ -21,12 +21,7 @@ #include #include -struct GraphData { - std::int64_t edge0; - std::int64_t edge1; - float weight; - DFE_NAMEDTUPLE(GraphData, edge0, edge1, weight); -}; +#include "CsvOutputData.hpp" ActsExamples::CsvExaTrkXGraphWriter::CsvExaTrkXGraphWriter( const ActsExamples::CsvExaTrkXGraphWriter::Config& config, @@ -35,20 +30,26 @@ ActsExamples::CsvExaTrkXGraphWriter::CsvExaTrkXGraphWriter( m_cfg(config) {} ActsExamples::ProcessCode ActsExamples::CsvExaTrkXGraphWriter::writeT( - const ActsExamples::AlgorithmContext& ctx, - const std::pair, std::vector>& graph) { + const ActsExamples::AlgorithmContext& ctx, const Graph& graph) { + assert(graph.weights.empty() || + (graph.edges.size() / 2 == graph.weights.size())); + assert(graph.edges.size() % 2 == 0); + + if (graph.weights.empty()) { + ACTS_DEBUG("No weights provide, write default value of 1"); + } + std::string path = perEventFilepath( m_cfg.outputDir, m_cfg.outputStem + ".csv", ctx.eventNumber); dfe::NamedTupleCsvWriter writer(path); - const auto& [edges, weights] = graph; - - for (auto i = 0ul; i < weights.size(); ++i) { + const auto nEdges = graph.edges.size() / 2; + for (auto i = 0ul; i < nEdges; ++i) { GraphData edge{}; - edge.edge0 = edges[2 * i]; - edge.edge1 = edges[2 * i + 1]; - edge.weight = weights[i]; + edge.edge0 = graph.edges[2 * i]; + edge.edge1 = graph.edges[2 * i + 1]; + edge.weight = graph.weights.empty() ? 1.f : graph.weights[i]; writer.append(edge); } diff --git a/Examples/Io/Csv/src/CsvOutputData.hpp b/Examples/Io/Csv/src/CsvOutputData.hpp index 925b8ff7719..e88109f612b 100644 --- a/Examples/Io/Csv/src/CsvOutputData.hpp +++ b/Examples/Io/Csv/src/CsvOutputData.hpp @@ -349,4 +349,11 @@ struct ProtoTrackData { DFE_NAMEDTUPLE(ProtoTrackData, trackId, measurementId, x, y, z); }; +struct GraphData { + std::int64_t edge0; + std::int64_t edge1; + float weight; + DFE_NAMEDTUPLE(GraphData, edge0, edge1, weight); +}; + } // namespace ActsExamples diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index d7fd8209d99..2ae72f0ded4 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -172,7 +172,7 @@ void addExaTrkXTrackFinding(Context &ctx) { ActsExamples::TruthGraphBuilder, mex, "TruthGraphBuilder", inputSpacePoints, inputSimHits, inputParticles, inputMeasurementSimHitsMap, inputMeasurementParticlesMap, outputGraph, - targetMinPT, targetMinSize); + targetMinPT, targetMinSize, uniqueModules); { auto cls = diff --git a/Examples/Python/src/Input.cpp b/Examples/Python/src/Input.cpp index ddbd33ede4f..2036e51674a 100644 --- a/Examples/Python/src/Input.cpp +++ b/Examples/Python/src/Input.cpp @@ -9,6 +9,7 @@ #include "Acts/Plugins/Python/Utilities.hpp" #include "ActsExamples/EventData/Cluster.hpp" #include "ActsExamples/Io/Csv/CsvDriftCircleReader.hpp" +#include "ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp" #include "ActsExamples/Io/Csv/CsvMeasurementReader.hpp" #include "ActsExamples/Io/Csv/CsvMuonSimHitReader.hpp" #include "ActsExamples/Io/Csv/CsvParticleReader.hpp" @@ -97,6 +98,10 @@ void addInput(Context& ctx) { ACTS_PYTHON_DECLARE_READER(ActsExamples::RootSimHitReader, mex, "RootSimHitReader", treeName, filePath, outputSimHits); + + ACTS_PYTHON_DECLARE_READER(ActsExamples::CsvExaTrkXGraphReader, mex, + "CsvExaTrkXGraphReader", inputDir, inputStem, + outputGraph); } } // namespace Acts::Python From cf7a23191527881629bc6d9e48fcddecf41bdbb7 Mon Sep 17 00:00:00 2001 From: Benjamin Huth Date: Tue, 9 Jul 2024 13:54:24 +0200 Subject: [PATCH 05/10] make clang tidy happy --- Examples/Io/Csv/src/CsvOutputData.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Examples/Io/Csv/src/CsvOutputData.hpp b/Examples/Io/Csv/src/CsvOutputData.hpp index e88109f612b..5cff2f69fa1 100644 --- a/Examples/Io/Csv/src/CsvOutputData.hpp +++ b/Examples/Io/Csv/src/CsvOutputData.hpp @@ -350,9 +350,9 @@ struct ProtoTrackData { }; struct GraphData { - std::int64_t edge0; - std::int64_t edge1; - float weight; + std::int64_t edge0 = 0; + std::int64_t edge1 = 0; + float weight = 0.0; DFE_NAMEDTUPLE(GraphData, edge0, edge1, weight); }; From e65f098a5c94f3ba2dd66f8c6e43ef1d10f24009 Mon Sep 17 00:00:00 2001 From: Benjamin Huth Date: Tue, 9 Jul 2024 14:04:21 +0200 Subject: [PATCH 06/10] fix --- Examples/Python/src/ExaTrkXTrackFinding.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index 2ae72f0ded4..171284ded1b 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -168,6 +168,13 @@ void addExaTrkXTrackFinding(Context &ctx) { } #endif + ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::TrackFindingAlgorithmExaTrkX, mex, + "TrackFindingAlgorithmExaTrkX", + inputSpacePoints, inputClusters, + inputTruthGraph, outputProtoTracks, outputGraph, + graphConstructor, edgeClassifiers, trackBuilder, + nodeFeatures, featureScales, filterShortTracks); + ACTS_PYTHON_DECLARE_ALGORITHM( ActsExamples::TruthGraphBuilder, mex, "TruthGraphBuilder", inputSpacePoints, inputSimHits, inputParticles, From 2210f80ead833e3a8929c2cc0d0b840c63d35e0b Mon Sep 17 00:00:00 2001 From: Benjamin Huth <37871400+benjaminhuth@users.noreply.github.com> Date: Tue, 9 Jul 2024 14:42:15 +0200 Subject: [PATCH 07/10] Apply suggestions from code review Co-authored-by: Andreas Stefl --- .../Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp | 4 ++-- Examples/Framework/include/ActsExamples/EventData/Graph.hpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp index a97a5604986..eef265a24ce 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp @@ -13,7 +13,7 @@ using namespace Acts::UnitLiterals; namespace ActsExamples { -ActsExamples::TruthGraphBuilder::TruthGraphBuilder(Config config, +TruthGraphBuilder::TruthGraphBuilder(Config config, Acts::Logging::Level level) : ActsExamples::IAlgorithm("TruthGraphBuilder", level), m_cfg(std::move(config)) { @@ -170,7 +170,7 @@ std::vector TruthGraphBuilder::buildFromSimhits( return truthGraph; } -ActsExamples::ProcessCode ActsExamples::TruthGraphBuilder::execute( +ProcessCode TruthGraphBuilder::execute( const ActsExamples::AlgorithmContext& ctx) const { // Read input data const auto& spacepoints = m_inputSpacePoints(ctx); diff --git a/Examples/Framework/include/ActsExamples/EventData/Graph.hpp b/Examples/Framework/include/ActsExamples/EventData/Graph.hpp index 5508958ccaa..5014f6548c4 100644 --- a/Examples/Framework/include/ActsExamples/EventData/Graph.hpp +++ b/Examples/Framework/include/ActsExamples/EventData/Graph.hpp @@ -16,7 +16,7 @@ namespace ActsExamples { /// Lightweight graph representation for GNN debugging /// struct Graph { - /// The edges-vector contains flattened edge-pairs.Usually, the indices + /// The edges-vector contains flattened edge-pairs. Usually, the indices /// refer to a spacepoint collection. /// /// Structure: [ source0, dest0, source1, dest1, ..., sourceN, destN ] From ed4b14cacd749f5c66cadb8de7c36b3c14028631 Mon Sep 17 00:00:00 2001 From: Benjamin Huth <37871400+benjaminhuth@users.noreply.github.com> Date: Thu, 11 Jul 2024 13:45:21 +0200 Subject: [PATCH 08/10] Update TruthGraphBuilder.cpp --- .../Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp index eef265a24ce..42acfdd9a41 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp @@ -29,7 +29,8 @@ TruthGraphBuilder::TruthGraphBuilder(Config config, bool b = m_inputSimhits.isInitialized() && m_inputMeasSimhitMap.isInitialized(); - if (!(a || b)) { + // Logical XOR operation + if (!a != !b) { throw std::invalid_argument("Missing inputs, cannot build truth graph"); } } From ba17b2ea1b176cb6d6b1667665401eac6cfbf268 Mon Sep 17 00:00:00 2001 From: Benjamin Huth <37871400+benjaminhuth@users.noreply.github.com> Date: Thu, 11 Jul 2024 13:49:07 +0200 Subject: [PATCH 09/10] Update ExaTrkXTrackFinding.cpp --- Examples/Python/src/ExaTrkXTrackFinding.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index a3cb6b46549..e25c680ff6f 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -168,13 +168,6 @@ void addExaTrkXTrackFinding(Context &ctx) { } #endif - ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::TrackFindingAlgorithmExaTrkX, mex, - "TrackFindingAlgorithmExaTrkX", - inputSpacePoints, inputClusters, - inputTruthGraph, outputProtoTracks, outputGraph, - graphConstructor, edgeClassifiers, trackBuilder, - nodeFeatures, featureScales, filterShortTracks); - ACTS_PYTHON_DECLARE_ALGORITHM( ActsExamples::TruthGraphBuilder, mex, "TruthGraphBuilder", inputSpacePoints, inputSimHits, inputParticles, @@ -205,6 +198,13 @@ void addExaTrkXTrackFinding(Context &ctx) { .value("Cluster2Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster2Eta); + ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::TrackFindingAlgorithmExaTrkX, mex, + "TrackFindingAlgorithmExaTrkX", + inputSpacePoints, inputClusters, + inputTruthGraph, outputProtoTracks, outputGraph, + graphConstructor, edgeClassifiers, trackBuilder, + nodeFeatures, featureScales, filterShortTracks); + { auto cls = py::class_>( From 5512d975a5078460950c91cf789467b227f8f4d1 Mon Sep 17 00:00:00 2001 From: Benjamin Huth Date: Thu, 11 Jul 2024 13:55:28 +0200 Subject: [PATCH 10/10] fix format --- .../Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp | 3 +-- Examples/Python/src/ExaTrkXTrackFinding.cpp | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp index 42acfdd9a41..4d8df390e50 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp @@ -13,8 +13,7 @@ using namespace Acts::UnitLiterals; namespace ActsExamples { -TruthGraphBuilder::TruthGraphBuilder(Config config, - Acts::Logging::Level level) +TruthGraphBuilder::TruthGraphBuilder(Config config, Acts::Logging::Level level) : ActsExamples::IAlgorithm("TruthGraphBuilder", level), m_cfg(std::move(config)) { m_inputSpacePoints.initialize(m_cfg.inputSpacePoints); diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index e25c680ff6f..2d9d0341362 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -173,7 +173,7 @@ void addExaTrkXTrackFinding(Context &ctx) { inputSpacePoints, inputSimHits, inputParticles, inputMeasurementSimHitsMap, inputMeasurementParticlesMap, outputGraph, targetMinPT, targetMinSize, uniqueModules); - + py::enum_(mex, "NodeFeature") .value("R", TrackFindingAlgorithmExaTrkX::NodeFeature::eR) .value("Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::ePhi)