Skip to content

Commit

Permalink
refactor: Truth graph metrics for Exa.TrkX with separate algorithm (#…
Browse files Browse the repository at this point in the history
…3354)

Refactors code for truth graph building into separate algorithm, also improve reader/writer infrastructure for graphs.
  • Loading branch information
benjaminhuth authored Jul 11, 2024
1 parent b5069e1 commit 2413095
Show file tree
Hide file tree
Showing 14 changed files with 526 additions and 150 deletions.
1 change: 1 addition & 0 deletions Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_library(
src/TrackFindingAlgorithmExaTrkX.cpp
src/PrototracksToParameters.cpp
src/TrackFindingFromPrototrackAlgorithm.cpp
src/TruthGraphBuilder.cpp
)

target_include_directories(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
#include "Acts/Plugins/ExaTrkX/Stages.hpp"
#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/Graph.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
#include "ActsExamples/EventData/SimHit.hpp"
#include "ActsExamples/EventData/SimParticle.hpp"
#include "ActsExamples/EventData/SimSpacePoint.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"
#include "ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp"

#include <mutex>
#include <string>
Expand Down Expand Up @@ -57,15 +59,9 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
std::string inputSpacePoints;
/// Input cluster information (Optional).
std::string inputClusters;

/// Input simhits (Optional).
std::string inputSimHits;
/// Input measurement simhit map (Optional).
std::string inputParticles;
/// Input measurement simhit map (Optional).
std::string inputMeasurementSimhitsMap;

/// Output protoTracks collection.
/// Input truth graph (Optional).
std::string inputTruthGraph;
/// Output prototracks
std::string outputProtoTracks;

/// Output graph (optional)
Expand All @@ -86,10 +82,6 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {

/// Remove track candidates with 2 or less hits
bool filterShortTracks = false;

/// Target graph properties
std::size_t targetMinHits = 3;
double targetMinPT = 500 * Acts::UnitConstants::MeV;
};

/// Constructor of the track finding algorithm
Expand Down Expand Up @@ -132,16 +124,10 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
"InputSpacePoints"};
ReadDataHandle<ClusterContainer> m_inputClusters{this, "InputClusters"};

ReadDataHandle<Graph> m_inputTruthGraph{this, "InputTruthGraph"};
WriteDataHandle<ProtoTrackContainer> m_outputProtoTracks{this,
"OutputProtoTracks"};
WriteDataHandle<Acts::TorchGraphStoreHook::Graph> m_outputGraph{
this, "OutputGraph"};

// for truth graph
ReadDataHandle<SimHitContainer> m_inputSimHits{this, "InputSimHits"};
ReadDataHandle<SimParticleContainer> m_inputParticles{this, "InputParticles"};
ReadDataHandle<IndexMultimap<Index>> m_inputMeasurementMap{
this, "InputMeasurementMap"};
WriteDataHandle<Graph> m_outputGraph{this, "OutputGraph"};
};

} // namespace ActsExamples
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// This file is part of the Acts project.
//
// Copyright (C) 2022 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Definitions/Units.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/Graph.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
#include "ActsExamples/EventData/SimHit.hpp"
#include "ActsExamples/EventData/SimParticle.hpp"
#include "ActsExamples/EventData/SimSpacePoint.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"

namespace ActsExamples {

/// Algorithm to create a truth graph for computing truth metrics
/// Requires spacepoints and particles collection
/// Either provide a measurements particles map, or a measurement simhit map +
/// simhits
class TruthGraphBuilder final : public IAlgorithm {
public:
struct Config {
/// Input spacepoint collection
std::string inputSpacePoints;
/// Input particles collection
std::string inputParticles;
/// Input measurement particles map (Optional).
std::string inputMeasurementParticlesMap;
/// Input simhits (Optional).
std::string inputSimHits;
/// Input measurement simhit map (Optional).
std::string inputMeasurementSimHitsMap;
/// Output truth graph
std::string outputGraph;

double targetMinPT = 0.5;
std::size_t targetMinSize = 3;

/// Only allow one hit per track & module
bool uniqueModules = false;
};

TruthGraphBuilder(Config cfg, Acts::Logging::Level lvl);

~TruthGraphBuilder() override = default;

ActsExamples::ProcessCode execute(
const ActsExamples::AlgorithmContext& ctx) const final;

const Config& config() const { return m_cfg; }

private:
Config m_cfg;

std::vector<std::int64_t> buildFromMeasurements(
const SimSpacePointContainer& spacepoints,
const SimParticleContainer& particles,
const IndexMultimap<ActsFatras::Barcode>& measPartMap) const;

std::vector<std::int64_t> buildFromSimhits(
const SimSpacePointContainer& spacepoints,
const IndexMultimap<Index>& measHitMap, const SimHitContainer& simhits,
const SimParticleContainer& particles) const;

ReadDataHandle<SimSpacePointContainer> m_inputSpacePoints{this,
"InputSpacePoints"};
ReadDataHandle<SimParticleContainer> m_inputParticles{this, "InputParticles"};
ReadDataHandle<IndexMultimap<ActsFatras::Barcode>> m_inputMeasParticlesMap{
this, "InputMeasParticlesMap"};
ReadDataHandle<SimHitContainer> m_inputSimhits{this, "InputSimhits"};
ReadDataHandle<IndexMultimap<Index>> m_inputMeasSimhitMap{
this, "InputMeasSimhitMap"};

WriteDataHandle<Graph> m_outputGraph{this, "OutputGraph"};
};
} // namespace ActsExamples
Original file line number Diff line number Diff line change
Expand Up @@ -25,97 +25,16 @@ using namespace Acts::UnitLiterals;

namespace {

class ExamplesEdmHook : public Acts::ExaTrkXHook {
double m_targetPT = 0.5_GeV;
std::size_t m_targetSize = 3;

std::unique_ptr<const Acts::Logger> m_logger;
std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_truthGraphHook;
std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_targetGraphHook;
std::unique_ptr<Acts::TorchGraphStoreHook> m_graphStoreHook;

const Acts::Logger& logger() const { return *m_logger; }

struct HitInfo {
std::size_t spacePointIndex;
std::int32_t hitIndex;
};

public:
ExamplesEdmHook(const SimSpacePointContainer& spacepoints,
const IndexMultimap<Index>& measHitMap,
const SimHitContainer& simhits,
const SimParticleContainer& particles,
std::size_t targetMinHits, double targetMinPT,
const Acts::Logger& logger)
: m_targetPT(targetMinPT),
m_targetSize(targetMinHits),
m_logger(logger.clone("MetricsHook")) {
// Associate tracks to graph, collect momentum
std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>> tracks;

for (auto i = 0ul; i < spacepoints.size(); ++i) {
const auto measId = spacepoints[i]
.sourceLinks()[0]
.template get<IndexSourceLink>()
.index();

auto [a, b] = measHitMap.equal_range(measId);
for (auto it = a; it != b; ++it) {
const auto& hit = *simhits.nth(it->second);

tracks[hit.particleId()].push_back({i, hit.index()});
}
}

// Collect edges for truth graph and target graph
std::vector<std::int64_t> truthGraph;
std::vector<std::int64_t> targetGraph;

for (auto& [pid, track] : tracks) {
// Sort by hit index, so the edges are connected correctly
std::sort(track.begin(), track.end(), [](const auto& a, const auto& b) {
return a.hitIndex < b.hitIndex;
});

auto found = particles.find(pid);
if (found == particles.end()) {
ACTS_WARNING("Did not find " << pid << ", skip track");
continue;
}

for (auto i = 0ul; i < track.size() - 1; ++i) {
truthGraph.push_back(track[i].spacePointIndex);
truthGraph.push_back(track[i + 1].spacePointIndex);

if (found->transverseMomentum() > m_targetPT &&
track.size() >= m_targetSize) {
targetGraph.push_back(track[i].spacePointIndex);
targetGraph.push_back(track[i + 1].spacePointIndex);
}
}
}
struct LoopHook : public Acts::ExaTrkXHook {
std::vector<Acts::ExaTrkXHook*> hooks;

m_truthGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
truthGraph, logger.clone());
m_targetGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
targetGraph, logger.clone());
m_graphStoreHook = std::make_unique<Acts::TorchGraphStoreHook>();
}

~ExamplesEdmHook() {}

auto storedGraph() const { return m_graphStoreHook->storedGraph(); }
~LoopHook() {}

void operator()(const std::any& nodes, const std::any& edges,
const std::any& weights) const override {
ACTS_INFO("Metrics for total graph:");
(*m_truthGraphHook)(nodes, edges, weights);
ACTS_INFO("Metrics for target graph (pT > "
<< m_targetPT / Acts::UnitConstants::GeV
<< " GeV, nHits >= " << m_targetSize << "):");
(*m_targetGraphHook)(nodes, edges, weights);
(*m_graphStoreHook)(nodes, edges, weights);
for (auto hook : hooks) {
(*hook)(nodes, edges, weights);
}
}
};

Expand Down Expand Up @@ -164,10 +83,7 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(
m_inputClusters.maybeInitialize(m_cfg.inputClusters);
m_outputProtoTracks.initialize(m_cfg.outputProtoTracks);

m_inputSimHits.maybeInitialize(m_cfg.inputSimHits);
m_inputParticles.maybeInitialize(m_cfg.inputParticles);
m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap);

m_inputTruthGraph.maybeInitialize(m_cfg.inputTruthGraph);
m_outputGraph.maybeInitialize(m_cfg.outputGraph);

// reserve space for timing
Expand Down Expand Up @@ -200,17 +116,25 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(

ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
const ActsExamples::AlgorithmContext& ctx) const {
// Read input data
auto spacepoints = m_inputSpacePoints(ctx);
// Setup hooks
LoopHook hook;

std::unique_ptr<Acts::TorchTruthGraphMetricsHook> truthGraphHook;
if (m_inputTruthGraph.isInitialized()) {
truthGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
m_inputTruthGraph(ctx).edges, this->logger().clone());
hook.hooks.push_back(&*truthGraphHook);
}

auto hook = std::make_unique<Acts::ExaTrkXHook>();
if (m_inputSimHits.isInitialized() && m_inputMeasurementMap.isInitialized()) {
hook = std::make_unique<ExamplesEdmHook>(
spacepoints, m_inputMeasurementMap(ctx), m_inputSimHits(ctx),
m_inputParticles(ctx), m_cfg.targetMinHits, m_cfg.targetMinPT,
logger());
std::unique_ptr<Acts::TorchGraphStoreHook> graphStoreHook;
if (m_outputGraph.isInitialized()) {
graphStoreHook = std::make_unique<Acts::TorchGraphStoreHook>();
hook.hooks.push_back(&*graphStoreHook);
}

// Read input data
auto spacepoints = m_inputSpacePoints(ctx);

std::optional<ClusterContainer> clusters;
if (m_inputClusters.isInitialized()) {
clusters = m_inputClusters(ctx);
Expand Down Expand Up @@ -288,7 +212,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
std::lock_guard<std::mutex> lock(m_mutex);

Acts::ExaTrkXTiming timing;
auto res = m_pipeline.run(features, spacepointIDs, *hook, &timing);
auto res = m_pipeline.run(features, spacepointIDs, hook, &timing);

m_timing.graphBuildingTime(timing.graphBuildingTime.count());

Expand Down Expand Up @@ -329,13 +253,12 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
ACTS_INFO("Created " << protoTracks.size() << " proto tracks");
m_outputProtoTracks(ctx, std::move(protoTracks));

if (auto dhook = dynamic_cast<ExamplesEdmHook*>(&*hook);
dhook && m_outputGraph.isInitialized()) {
auto graph = dhook->storedGraph();
if (m_outputGraph.isInitialized()) {
auto graph = graphStoreHook->storedGraph();
std::transform(
graph.first.begin(), graph.first.end(), graph.first.begin(),
[&](const auto& a) -> std::int64_t { return spacepointIDs.at(a); });
m_outputGraph(ctx, std::move(graph));
m_outputGraph(ctx, {graph.first, graph.second});
}

return ActsExamples::ProcessCode::SUCCESS;
Expand Down
Loading

0 comments on commit 2413095

Please sign in to comment.