diff --git a/Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt b/Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt index b4a0f280555..5b7ea5e820e 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt +++ b/Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt @@ -3,6 +3,7 @@ add_library( src/TrackFindingAlgorithmExaTrkX.cpp src/PrototracksToParameters.cpp src/TrackFindingFromPrototrackAlgorithm.cpp + src/TruthGraphBuilder.cpp ) target_include_directories( diff --git a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp index 29714b2f57a..18f51b5e99c 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp @@ -13,12 +13,14 @@ #include "Acts/Plugins/ExaTrkX/Stages.hpp" #include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" #include "ActsExamples/EventData/Cluster.hpp" +#include "ActsExamples/EventData/Graph.hpp" #include "ActsExamples/EventData/ProtoTrack.hpp" #include "ActsExamples/EventData/SimHit.hpp" #include "ActsExamples/EventData/SimParticle.hpp" #include "ActsExamples/EventData/SimSpacePoint.hpp" #include "ActsExamples/Framework/DataHandle.hpp" #include "ActsExamples/Framework/IAlgorithm.hpp" +#include "ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp" #include #include @@ -57,15 +59,9 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { std::string inputSpacePoints; /// Input cluster information (Optional). std::string inputClusters; - - /// Input simhits (Optional). - std::string inputSimHits; - /// Input measurement simhit map (Optional). - std::string inputParticles; - /// Input measurement simhit map (Optional). - std::string inputMeasurementSimhitsMap; - - /// Output protoTracks collection. + /// Input truth graph (Optional). + std::string inputTruthGraph; + /// Output prototracks std::string outputProtoTracks; /// Output graph (optional) @@ -86,10 +82,6 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { /// Remove track candidates with 2 or less hits bool filterShortTracks = false; - - /// Target graph properties - std::size_t targetMinHits = 3; - double targetMinPT = 500 * Acts::UnitConstants::MeV; }; /// Constructor of the track finding algorithm @@ -132,16 +124,10 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { "InputSpacePoints"}; ReadDataHandle m_inputClusters{this, "InputClusters"}; + ReadDataHandle m_inputTruthGraph{this, "InputTruthGraph"}; WriteDataHandle m_outputProtoTracks{this, "OutputProtoTracks"}; - WriteDataHandle m_outputGraph{ - this, "OutputGraph"}; - - // for truth graph - ReadDataHandle m_inputSimHits{this, "InputSimHits"}; - ReadDataHandle m_inputParticles{this, "InputParticles"}; - ReadDataHandle> m_inputMeasurementMap{ - this, "InputMeasurementMap"}; + WriteDataHandle m_outputGraph{this, "OutputGraph"}; }; } // namespace ActsExamples diff --git a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp new file mode 100644 index 00000000000..97c41acc2bb --- /dev/null +++ b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp @@ -0,0 +1,83 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2022 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Definitions/Units.hpp" +#include "ActsExamples/EventData/Cluster.hpp" +#include "ActsExamples/EventData/Graph.hpp" +#include "ActsExamples/EventData/ProtoTrack.hpp" +#include "ActsExamples/EventData/SimHit.hpp" +#include "ActsExamples/EventData/SimParticle.hpp" +#include "ActsExamples/EventData/SimSpacePoint.hpp" +#include "ActsExamples/Framework/DataHandle.hpp" +#include "ActsExamples/Framework/IAlgorithm.hpp" + +namespace ActsExamples { + +/// Algorithm to create a truth graph for computing truth metrics +/// Requires spacepoints and particles collection +/// Either provide a measurements particles map, or a measurement simhit map + +/// simhits +class TruthGraphBuilder final : public IAlgorithm { + public: + struct Config { + /// Input spacepoint collection + std::string inputSpacePoints; + /// Input particles collection + std::string inputParticles; + /// Input measurement particles map (Optional). + std::string inputMeasurementParticlesMap; + /// Input simhits (Optional). + std::string inputSimHits; + /// Input measurement simhit map (Optional). + std::string inputMeasurementSimHitsMap; + /// Output truth graph + std::string outputGraph; + + double targetMinPT = 0.5; + std::size_t targetMinSize = 3; + + /// Only allow one hit per track & module + bool uniqueModules = false; + }; + + TruthGraphBuilder(Config cfg, Acts::Logging::Level lvl); + + ~TruthGraphBuilder() override = default; + + ActsExamples::ProcessCode execute( + const ActsExamples::AlgorithmContext& ctx) const final; + + const Config& config() const { return m_cfg; } + + private: + Config m_cfg; + + std::vector buildFromMeasurements( + const SimSpacePointContainer& spacepoints, + const SimParticleContainer& particles, + const IndexMultimap& measPartMap) const; + + std::vector buildFromSimhits( + const SimSpacePointContainer& spacepoints, + const IndexMultimap& measHitMap, const SimHitContainer& simhits, + const SimParticleContainer& particles) const; + + ReadDataHandle m_inputSpacePoints{this, + "InputSpacePoints"}; + ReadDataHandle m_inputParticles{this, "InputParticles"}; + ReadDataHandle> m_inputMeasParticlesMap{ + this, "InputMeasParticlesMap"}; + ReadDataHandle m_inputSimhits{this, "InputSimhits"}; + ReadDataHandle> m_inputMeasSimhitMap{ + this, "InputMeasSimhitMap"}; + + WriteDataHandle m_outputGraph{this, "OutputGraph"}; +}; +} // namespace ActsExamples diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp index 33c5a6e1f91..6747097c88f 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp @@ -25,97 +25,16 @@ using namespace Acts::UnitLiterals; namespace { -class ExamplesEdmHook : public Acts::ExaTrkXHook { - double m_targetPT = 0.5_GeV; - std::size_t m_targetSize = 3; - - std::unique_ptr m_logger; - std::unique_ptr m_truthGraphHook; - std::unique_ptr m_targetGraphHook; - std::unique_ptr m_graphStoreHook; - - const Acts::Logger& logger() const { return *m_logger; } - - struct HitInfo { - std::size_t spacePointIndex; - std::int32_t hitIndex; - }; - - public: - ExamplesEdmHook(const SimSpacePointContainer& spacepoints, - const IndexMultimap& measHitMap, - const SimHitContainer& simhits, - const SimParticleContainer& particles, - std::size_t targetMinHits, double targetMinPT, - const Acts::Logger& logger) - : m_targetPT(targetMinPT), - m_targetSize(targetMinHits), - m_logger(logger.clone("MetricsHook")) { - // Associate tracks to graph, collect momentum - std::unordered_map> tracks; - - for (auto i = 0ul; i < spacepoints.size(); ++i) { - const auto measId = spacepoints[i] - .sourceLinks()[0] - .template get() - .index(); - - auto [a, b] = measHitMap.equal_range(measId); - for (auto it = a; it != b; ++it) { - const auto& hit = *simhits.nth(it->second); - - tracks[hit.particleId()].push_back({i, hit.index()}); - } - } - - // Collect edges for truth graph and target graph - std::vector truthGraph; - std::vector targetGraph; - - for (auto& [pid, track] : tracks) { - // Sort by hit index, so the edges are connected correctly - std::sort(track.begin(), track.end(), [](const auto& a, const auto& b) { - return a.hitIndex < b.hitIndex; - }); - - auto found = particles.find(pid); - if (found == particles.end()) { - ACTS_WARNING("Did not find " << pid << ", skip track"); - continue; - } - - for (auto i = 0ul; i < track.size() - 1; ++i) { - truthGraph.push_back(track[i].spacePointIndex); - truthGraph.push_back(track[i + 1].spacePointIndex); - - if (found->transverseMomentum() > m_targetPT && - track.size() >= m_targetSize) { - targetGraph.push_back(track[i].spacePointIndex); - targetGraph.push_back(track[i + 1].spacePointIndex); - } - } - } +struct LoopHook : public Acts::ExaTrkXHook { + std::vector hooks; - m_truthGraphHook = std::make_unique( - truthGraph, logger.clone()); - m_targetGraphHook = std::make_unique( - targetGraph, logger.clone()); - m_graphStoreHook = std::make_unique(); - } - - ~ExamplesEdmHook() {} - - auto storedGraph() const { return m_graphStoreHook->storedGraph(); } + ~LoopHook() {} void operator()(const std::any& nodes, const std::any& edges, const std::any& weights) const override { - ACTS_INFO("Metrics for total graph:"); - (*m_truthGraphHook)(nodes, edges, weights); - ACTS_INFO("Metrics for target graph (pT > " - << m_targetPT / Acts::UnitConstants::GeV - << " GeV, nHits >= " << m_targetSize << "):"); - (*m_targetGraphHook)(nodes, edges, weights); - (*m_graphStoreHook)(nodes, edges, weights); + for (auto hook : hooks) { + (*hook)(nodes, edges, weights); + } } }; @@ -164,10 +83,7 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX( m_inputClusters.maybeInitialize(m_cfg.inputClusters); m_outputProtoTracks.initialize(m_cfg.outputProtoTracks); - m_inputSimHits.maybeInitialize(m_cfg.inputSimHits); - m_inputParticles.maybeInitialize(m_cfg.inputParticles); - m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap); - + m_inputTruthGraph.maybeInitialize(m_cfg.inputTruthGraph); m_outputGraph.maybeInitialize(m_cfg.outputGraph); // reserve space for timing @@ -200,17 +116,25 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX( ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( const ActsExamples::AlgorithmContext& ctx) const { - // Read input data - auto spacepoints = m_inputSpacePoints(ctx); + // Setup hooks + LoopHook hook; + + std::unique_ptr truthGraphHook; + if (m_inputTruthGraph.isInitialized()) { + truthGraphHook = std::make_unique( + m_inputTruthGraph(ctx).edges, this->logger().clone()); + hook.hooks.push_back(&*truthGraphHook); + } - auto hook = std::make_unique(); - if (m_inputSimHits.isInitialized() && m_inputMeasurementMap.isInitialized()) { - hook = std::make_unique( - spacepoints, m_inputMeasurementMap(ctx), m_inputSimHits(ctx), - m_inputParticles(ctx), m_cfg.targetMinHits, m_cfg.targetMinPT, - logger()); + std::unique_ptr graphStoreHook; + if (m_outputGraph.isInitialized()) { + graphStoreHook = std::make_unique(); + hook.hooks.push_back(&*graphStoreHook); } + // Read input data + auto spacepoints = m_inputSpacePoints(ctx); + std::optional clusters; if (m_inputClusters.isInitialized()) { clusters = m_inputClusters(ctx); @@ -288,7 +212,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( std::lock_guard lock(m_mutex); Acts::ExaTrkXTiming timing; - auto res = m_pipeline.run(features, spacepointIDs, *hook, &timing); + auto res = m_pipeline.run(features, spacepointIDs, hook, &timing); m_timing.graphBuildingTime(timing.graphBuildingTime.count()); @@ -329,13 +253,12 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( ACTS_INFO("Created " << protoTracks.size() << " proto tracks"); m_outputProtoTracks(ctx, std::move(protoTracks)); - if (auto dhook = dynamic_cast(&*hook); - dhook && m_outputGraph.isInitialized()) { - auto graph = dhook->storedGraph(); + if (m_outputGraph.isInitialized()) { + auto graph = graphStoreHook->storedGraph(); std::transform( graph.first.begin(), graph.first.end(), graph.first.begin(), [&](const auto& a) -> std::int64_t { return spacepointIDs.at(a); }); - m_outputGraph(ctx, std::move(graph)); + m_outputGraph(ctx, {graph.first, graph.second}); } return ActsExamples::ProcessCode::SUCCESS; diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp new file mode 100644 index 00000000000..4d8df390e50 --- /dev/null +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TruthGraphBuilder.cpp @@ -0,0 +1,195 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2022 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include +#include + +using namespace Acts::UnitLiterals; + +namespace ActsExamples { + +TruthGraphBuilder::TruthGraphBuilder(Config config, Acts::Logging::Level level) + : ActsExamples::IAlgorithm("TruthGraphBuilder", level), + m_cfg(std::move(config)) { + m_inputSpacePoints.initialize(m_cfg.inputSpacePoints); + m_inputParticles.initialize(m_cfg.inputParticles); + m_outputGraph.initialize(m_cfg.outputGraph); + + m_inputMeasParticlesMap.maybeInitialize(m_cfg.inputMeasurementParticlesMap); + m_inputSimhits.maybeInitialize(m_cfg.inputSimHits); + m_inputMeasSimhitMap.maybeInitialize(m_cfg.inputMeasurementSimHitsMap); + + bool a = m_inputMeasParticlesMap.isInitialized(); + bool b = + m_inputSimhits.isInitialized() && m_inputMeasSimhitMap.isInitialized(); + + // Logical XOR operation + if (!a != !b) { + throw std::invalid_argument("Missing inputs, cannot build truth graph"); + } +} + +std::vector TruthGraphBuilder::buildFromMeasurements( + const SimSpacePointContainer& spacepoints, + const SimParticleContainer& particles, + const IndexMultimap& measPartMap) const { + if (m_cfg.targetMinPT < 500_MeV) { + ACTS_WARNING( + "truth graph building based on distance from origin, this breaks down " + "for low pT particles. Consider using a higher target pT value"); + } + + // Associate tracks to graph, collect momentum + std::unordered_map> tracks; + + for (auto i = 0ul; i < spacepoints.size(); ++i) { + const auto measId = + spacepoints[i].sourceLinks()[0].template get().index(); + + auto [a, b] = measPartMap.equal_range(measId); + for (auto it = a; it != b; ++it) { + tracks[it->second].push_back(i); + } + } + + // Collect edges for truth graph and target graph + std::vector graph; + std::size_t notFoundParticles = 0; + std::size_t moduleDuplicatesRemoved = 0; + + for (auto& [pid, track] : tracks) { + auto found = particles.find(pid); + if (found == particles.end()) { + ACTS_VERBOSE("Did not find " << pid << ", skip track"); + notFoundParticles++; + continue; + } + + if (found->transverseMomentum() < m_cfg.targetMinPT || + track.size() < m_cfg.targetMinSize) { + continue; + } + + const Acts::Vector3 vtx = found->fourPosition().segment<3>(0); + auto radiusForOrdering = [&](std::size_t i) { + const auto& sp = spacepoints[i]; + return std::hypot(sp.x() - vtx[0], sp.y() - vtx[1], sp.z() - vtx[2]); + }; + + // Sort by radius (this breaks down if the particle has to low momentum) + std::sort(track.begin(), track.end(), [&](const auto& a, const auto& b) { + return radiusForOrdering(a) < radiusForOrdering(b); + }); + + if (m_cfg.uniqueModules) { + auto newEnd = std::unique( + track.begin(), track.end(), [&](const auto& a, const auto& b) { + auto gidA = spacepoints[a] + .sourceLinks()[0] + .template get() + .geometryId(); + auto gidB = spacepoints[b] + .sourceLinks()[0] + .template get() + .geometryId(); + return gidA == gidB; + }); + moduleDuplicatesRemoved += std::distance(newEnd, track.end()); + track.erase(newEnd, track.end()); + } + + for (auto i = 0ul; i < track.size() - 1; ++i) { + graph.push_back(track[i]); + graph.push_back(track[i + 1]); + } + } + + ACTS_DEBUG("Did not find particles for " << notFoundParticles << " tracks"); + if (moduleDuplicatesRemoved > 0) { + ACTS_DEBUG( + "Removed " << moduleDuplicatesRemoved + << " hit to ensure a unique hit per track and module"); + } + + return graph; +} + +struct HitInfo { + std::size_t spacePointIndex; + std::int32_t hitIndex; +}; + +std::vector TruthGraphBuilder::buildFromSimhits( + const SimSpacePointContainer& spacepoints, + const IndexMultimap& measHitMap, const SimHitContainer& simhits, + const SimParticleContainer& particles) const { + // Associate tracks to graph, collect momentum + std::unordered_map> tracks; + + for (auto i = 0ul; i < spacepoints.size(); ++i) { + const auto measId = + spacepoints[i].sourceLinks()[0].template get().index(); + + auto [a, b] = measHitMap.equal_range(measId); + for (auto it = a; it != b; ++it) { + const auto& hit = *simhits.nth(it->second); + + tracks[hit.particleId()].push_back({i, hit.index()}); + } + } + + // Collect edges for truth graph and target graph + std::vector truthGraph; + + for (auto& [pid, track] : tracks) { + // Sort by hit index, so the edges are connected correctly + std::sort(track.begin(), track.end(), [](const auto& a, const auto& b) { + return a.hitIndex < b.hitIndex; + }); + + auto found = particles.find(pid); + if (found == particles.end()) { + ACTS_WARNING("Did not find " << pid << ", skip track"); + continue; + } + + for (auto i = 0ul; i < track.size() - 1; ++i) { + if (found->transverseMomentum() > m_cfg.targetMinPT && + track.size() >= m_cfg.targetMinSize) { + truthGraph.push_back(track[i].spacePointIndex); + truthGraph.push_back(track[i + 1].spacePointIndex); + } + } + } + + return truthGraph; +} + +ProcessCode TruthGraphBuilder::execute( + const ActsExamples::AlgorithmContext& ctx) const { + // Read input data + const auto& spacepoints = m_inputSpacePoints(ctx); + const auto& particles = m_inputParticles(ctx); + + auto edges = (m_inputMeasParticlesMap.isInitialized()) + ? buildFromMeasurements(spacepoints, particles, + m_inputMeasParticlesMap(ctx)) + : buildFromSimhits(spacepoints, m_inputMeasSimhitMap(ctx), + m_inputSimhits(ctx), particles); + + ACTS_DEBUG("Truth track edges: " << edges.size() / 2); + + Graph g; + g.edges = std::move(edges); + + m_outputGraph(ctx, std::move(g)); + + return ProcessCode::SUCCESS; +} + +} // namespace ActsExamples diff --git a/Examples/Framework/include/ActsExamples/EventData/Graph.hpp b/Examples/Framework/include/ActsExamples/EventData/Graph.hpp new file mode 100644 index 00000000000..5014f6548c4 --- /dev/null +++ b/Examples/Framework/include/ActsExamples/EventData/Graph.hpp @@ -0,0 +1,32 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2024 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include +#include + +namespace ActsExamples { + +/// Lightweight graph representation for GNN debugging +/// +struct Graph { + /// The edges-vector contains flattened edge-pairs. Usually, the indices + /// refer to a spacepoint collection. + /// + /// Structure: [ source0, dest0, source1, dest1, ..., sourceN, destN ] + std::vector edges; + + /// The weight-vector should have half the size of the edges-vector (or + /// be empty if missing). + /// + /// Structure: [ weight0, weight1, ..., weightN ] + std::vector weights; +}; + +} // namespace ActsExamples diff --git a/Examples/Io/Csv/CMakeLists.txt b/Examples/Io/Csv/CMakeLists.txt index a2479cc9206..7d367d68892 100644 --- a/Examples/Io/Csv/CMakeLists.txt +++ b/Examples/Io/Csv/CMakeLists.txt @@ -19,6 +19,7 @@ add_library( src/CsvProtoTrackWriter.cpp src/CsvSpacePointWriter.cpp src/CsvExaTrkXGraphWriter.cpp + src/CsvExaTrkXGraphReader.cpp src/CsvBFieldWriter.cpp) target_include_directories( ActsExamplesIoCsv diff --git a/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp new file mode 100644 index 00000000000..4b41d57a448 --- /dev/null +++ b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp @@ -0,0 +1,64 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2022 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/Graph.hpp" +#include "ActsExamples/Framework/DataHandle.hpp" +#include "ActsExamples/Framework/IReader.hpp" +#include "ActsExamples/Framework/ProcessCode.hpp" + +#include +#include +#include +#include +#include + +namespace ActsExamples { +struct AlgorithmContext; + +class CsvExaTrkXGraphReader final : public IReader { + public: + struct Config { + /// Where to read input files from. + std::string inputDir; + /// Input filename stem. + std::string inputStem; + /// Which vs collection to read into. + std::string outputGraph; + }; + + /// Construct the particle reader. + /// + /// @param config is the configuration object + /// @param level is the logging level + CsvExaTrkXGraphReader(const Config& config, Acts::Logging::Level level); + + /// Return the available events range. + std::pair availableEvents() const override; + + /// Read out data from the input stream. + ProcessCode read(const ActsExamples::AlgorithmContext& ctx) override; + + /// Return the name of the component + std::string name() const override { return "CsvExaTrkXGraphReader"; } + + /// Readonly access to the config + const Config& config() const { return m_cfg; } + + private: + Config m_cfg; + std::pair m_eventsRange; + std::unique_ptr m_logger; + + WriteDataHandle m_outputGraph{this, "OutputGraph"}; + const Acts::Logger& logger() const { return *m_logger; } +}; + +} // namespace ActsExamples diff --git a/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp index eaf88c97a1d..15a2f7bb182 100644 --- a/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp +++ b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp @@ -9,6 +9,7 @@ #pragma once #include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/Graph.hpp" #include "ActsExamples/Framework/ProcessCode.hpp" #include "ActsExamples/Framework/WriterT.hpp" #include "ActsExamples/Utilities/Paths.hpp" @@ -20,8 +21,7 @@ namespace ActsExamples { struct AlgorithmContext; -class CsvExaTrkXGraphWriter final - : public WriterT, std::vector>> { +class CsvExaTrkXGraphWriter final : public WriterT { public: struct Config { /// Which simulated (truth) hits collection to use. @@ -46,9 +46,7 @@ class CsvExaTrkXGraphWriter final /// /// @param[in] ctx is the algorithm context /// @param[in] simHits are the simhits to be written - ProcessCode writeT(const AlgorithmContext& ctx, - const std::pair, - std::vector>& graph) override; + ProcessCode writeT(const AlgorithmContext& ctx, const Graph& graph) override; private: Config m_cfg; diff --git a/Examples/Io/Csv/src/CsvExaTrkXGraphReader.cpp b/Examples/Io/Csv/src/CsvExaTrkXGraphReader.cpp new file mode 100644 index 00000000000..81eaad7eb51 --- /dev/null +++ b/Examples/Io/Csv/src/CsvExaTrkXGraphReader.cpp @@ -0,0 +1,73 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2017 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp" + +#include "Acts/Definitions/PdgParticle.hpp" +#include "Acts/Definitions/Units.hpp" +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/SimParticle.hpp" +#include "ActsExamples/Framework/AlgorithmContext.hpp" +#include "ActsExamples/Utilities/Paths.hpp" +#include "ActsFatras/EventData/Barcode.hpp" +#include "ActsFatras/EventData/Particle.hpp" +#include "ActsFatras/EventData/ProcessType.hpp" +#include + +#include +#include +#include +#include + +#include + +#include "CsvOutputData.hpp" + +namespace ActsExamples { + +CsvExaTrkXGraphReader::CsvExaTrkXGraphReader(const Config& config, + Acts::Logging::Level level) + : m_cfg(config), + m_eventsRange( + determineEventFilesRange(m_cfg.inputDir, m_cfg.inputStem + ".csv")), + m_logger(Acts::getDefaultLogger("CsvExaTrkXGraphReader", level)) { + if (m_cfg.inputStem.empty()) { + throw std::invalid_argument("Missing input filename stem"); + } + + m_outputGraph.initialize(m_cfg.outputGraph); +} + +std::pair CsvExaTrkXGraphReader::availableEvents() + const { + return m_eventsRange; +} + +ProcessCode CsvExaTrkXGraphReader::read(const AlgorithmContext& ctx) { + SimParticleContainer::sequence_type unordered; + + auto path = perEventFilepath(m_cfg.inputDir, m_cfg.inputStem + ".csv", + ctx.eventNumber); + // vt and m are an optional columns + dfe::NamedTupleCsvReader reader(path, {"vt", "m"}); + GraphData data; + + Graph g; + + while (reader.read(data)) { + g.edges.push_back(data.edge0); + g.edges.push_back(data.edge1); + g.weights.push_back(data.weight); + } + + m_outputGraph(ctx, std::move(g)); + + return ProcessCode::SUCCESS; +} + +} // namespace ActsExamples diff --git a/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp b/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp index fd7e0da98c6..49c394b7ed2 100644 --- a/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp +++ b/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp @@ -21,12 +21,7 @@ #include #include -struct GraphData { - std::int64_t edge0; - std::int64_t edge1; - float weight; - DFE_NAMEDTUPLE(GraphData, edge0, edge1, weight); -}; +#include "CsvOutputData.hpp" ActsExamples::CsvExaTrkXGraphWriter::CsvExaTrkXGraphWriter( const ActsExamples::CsvExaTrkXGraphWriter::Config& config, @@ -35,20 +30,26 @@ ActsExamples::CsvExaTrkXGraphWriter::CsvExaTrkXGraphWriter( m_cfg(config) {} ActsExamples::ProcessCode ActsExamples::CsvExaTrkXGraphWriter::writeT( - const ActsExamples::AlgorithmContext& ctx, - const std::pair, std::vector>& graph) { + const ActsExamples::AlgorithmContext& ctx, const Graph& graph) { + assert(graph.weights.empty() || + (graph.edges.size() / 2 == graph.weights.size())); + assert(graph.edges.size() % 2 == 0); + + if (graph.weights.empty()) { + ACTS_DEBUG("No weights provide, write default value of 1"); + } + std::string path = perEventFilepath( m_cfg.outputDir, m_cfg.outputStem + ".csv", ctx.eventNumber); dfe::NamedTupleCsvWriter writer(path); - const auto& [edges, weights] = graph; - - for (auto i = 0ul; i < weights.size(); ++i) { + const auto nEdges = graph.edges.size() / 2; + for (auto i = 0ul; i < nEdges; ++i) { GraphData edge{}; - edge.edge0 = edges[2 * i]; - edge.edge1 = edges[2 * i + 1]; - edge.weight = weights[i]; + edge.edge0 = graph.edges[2 * i]; + edge.edge1 = graph.edges[2 * i + 1]; + edge.weight = graph.weights.empty() ? 1.f : graph.weights[i]; writer.append(edge); } diff --git a/Examples/Io/Csv/src/CsvOutputData.hpp b/Examples/Io/Csv/src/CsvOutputData.hpp index 925b8ff7719..5cff2f69fa1 100644 --- a/Examples/Io/Csv/src/CsvOutputData.hpp +++ b/Examples/Io/Csv/src/CsvOutputData.hpp @@ -349,4 +349,11 @@ struct ProtoTrackData { DFE_NAMEDTUPLE(ProtoTrackData, trackId, measurementId, x, y, z); }; +struct GraphData { + std::int64_t edge0 = 0; + std::int64_t edge1 = 0; + float weight = 0.0; + DFE_NAMEDTUPLE(GraphData, edge0, edge1, weight); +}; + } // namespace ActsExamples diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index 90c1002f00c..2d9d0341362 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -18,6 +18,7 @@ #include "ActsExamples/TrackFindingExaTrkX/PrototracksToParameters.hpp" #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp" #include "ActsExamples/TrackFindingExaTrkX/TrackFindingFromPrototrackAlgorithm.hpp" +#include "ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp" #include @@ -167,6 +168,12 @@ void addExaTrkXTrackFinding(Context &ctx) { } #endif + ACTS_PYTHON_DECLARE_ALGORITHM( + ActsExamples::TruthGraphBuilder, mex, "TruthGraphBuilder", + inputSpacePoints, inputSimHits, inputParticles, + inputMeasurementSimHitsMap, inputMeasurementParticlesMap, outputGraph, + targetMinPT, targetMinSize, uniqueModules); + py::enum_(mex, "NodeFeature") .value("R", TrackFindingAlgorithmExaTrkX::NodeFeature::eR) .value("Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::ePhi) @@ -191,12 +198,12 @@ void addExaTrkXTrackFinding(Context &ctx) { .value("Cluster2Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster2Eta); - ACTS_PYTHON_DECLARE_ALGORITHM( - ActsExamples::TrackFindingAlgorithmExaTrkX, mex, - "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits, - inputParticles, inputClusters, inputMeasurementSimhitsMap, - outputProtoTracks, outputGraph, graphConstructor, edgeClassifiers, - trackBuilder, nodeFeatures, featureScales, filterShortTracks); + ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::TrackFindingAlgorithmExaTrkX, mex, + "TrackFindingAlgorithmExaTrkX", + inputSpacePoints, inputClusters, + inputTruthGraph, outputProtoTracks, outputGraph, + graphConstructor, edgeClassifiers, trackBuilder, + nodeFeatures, featureScales, filterShortTracks); { auto cls = diff --git a/Examples/Python/src/Input.cpp b/Examples/Python/src/Input.cpp index ddbd33ede4f..2036e51674a 100644 --- a/Examples/Python/src/Input.cpp +++ b/Examples/Python/src/Input.cpp @@ -9,6 +9,7 @@ #include "Acts/Plugins/Python/Utilities.hpp" #include "ActsExamples/EventData/Cluster.hpp" #include "ActsExamples/Io/Csv/CsvDriftCircleReader.hpp" +#include "ActsExamples/Io/Csv/CsvExaTrkXGraphReader.hpp" #include "ActsExamples/Io/Csv/CsvMeasurementReader.hpp" #include "ActsExamples/Io/Csv/CsvMuonSimHitReader.hpp" #include "ActsExamples/Io/Csv/CsvParticleReader.hpp" @@ -97,6 +98,10 @@ void addInput(Context& ctx) { ACTS_PYTHON_DECLARE_READER(ActsExamples::RootSimHitReader, mex, "RootSimHitReader", treeName, filePath, outputSimHits); + + ACTS_PYTHON_DECLARE_READER(ActsExamples::CsvExaTrkXGraphReader, mex, + "CsvExaTrkXGraphReader", inputDir, inputStem, + outputGraph); } } // namespace Acts::Python