Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Truth graph metrics for Exa.TrkX with separate algorithm #3354

1 change: 1 addition & 0 deletions Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_library(
src/TrackFindingAlgorithmExaTrkX.cpp
src/PrototracksToParameters.cpp
src/TrackFindingFromPrototrackAlgorithm.cpp
src/TruthGraphBuilder.cpp
)

target_include_directories(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
#include "Acts/Plugins/ExaTrkX/Stages.hpp"
#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/Graph.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
#include "ActsExamples/EventData/SimHit.hpp"
#include "ActsExamples/EventData/SimParticle.hpp"
#include "ActsExamples/EventData/SimSpacePoint.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"
#include "ActsExamples/TrackFindingExaTrkX/TruthGraphBuilder.hpp"

#include <mutex>
#include <string>
Expand Down Expand Up @@ -57,15 +59,9 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
std::string inputSpacePoints;
/// Input cluster information (Optional).
std::string inputClusters;

/// Input simhits (Optional).
std::string inputSimHits;
/// Input measurement simhit map (Optional).
std::string inputParticles;
/// Input measurement simhit map (Optional).
std::string inputMeasurementSimhitsMap;

/// Output protoTracks collection.
/// Input truth graph (Optional).
std::string inputTruthGraph;
/// Output prototracks
std::string outputProtoTracks;

/// Output graph (optional)
Expand All @@ -86,10 +82,6 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {

/// Remove track candidates with 2 or less hits
bool filterShortTracks = false;

/// Target graph properties
std::size_t targetMinHits = 3;
double targetMinPT = 500 * Acts::UnitConstants::MeV;
};

/// Constructor of the track finding algorithm
Expand Down Expand Up @@ -132,16 +124,10 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
"InputSpacePoints"};
ReadDataHandle<ClusterContainer> m_inputClusters{this, "InputClusters"};

ReadDataHandle<Graph> m_inputTruthGraph{this, "InputTruthGraph"};
WriteDataHandle<ProtoTrackContainer> m_outputProtoTracks{this,
"OutputProtoTracks"};
WriteDataHandle<Acts::TorchGraphStoreHook::Graph> m_outputGraph{
this, "OutputGraph"};

// for truth graph
ReadDataHandle<SimHitContainer> m_inputSimHits{this, "InputSimHits"};
ReadDataHandle<SimParticleContainer> m_inputParticles{this, "InputParticles"};
ReadDataHandle<IndexMultimap<Index>> m_inputMeasurementMap{
this, "InputMeasurementMap"};
WriteDataHandle<Graph> m_outputGraph{this, "OutputGraph"};
};

} // namespace ActsExamples
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// This file is part of the Acts project.
//
// Copyright (C) 2022 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Definitions/Units.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/Graph.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
#include "ActsExamples/EventData/SimHit.hpp"
#include "ActsExamples/EventData/SimParticle.hpp"
#include "ActsExamples/EventData/SimSpacePoint.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"

namespace ActsExamples {

/// Algorithm to create a truth graph for computing truth metrics
/// Requires spacepoints and particles collection
/// Either provide a measurements particles map, or a measurement simhit map +
/// simhits
class TruthGraphBuilder final : public IAlgorithm {
public:
struct Config {
/// Input spacepoint collection
std::string inputSpacePoints;
/// Input particles collection
std::string inputParticles;
/// Input measurement particles map (Optional).
std::string inputMeasurementParticlesMap;
/// Input simhits (Optional).
std::string inputSimHits;
/// Input measurement simhit map (Optional).
std::string inputMeasurementSimHitsMap;
/// Output truth graph
std::string outputGraph;

double targetMinPT = 0.5;
std::size_t targetMinSize = 3;

/// Only allow one hit per track & module
bool uniqueModules = false;
};

TruthGraphBuilder(Config cfg, Acts::Logging::Level lvl);

~TruthGraphBuilder() override = default;

ActsExamples::ProcessCode execute(
const ActsExamples::AlgorithmContext& ctx) const final;

const Config& config() const { return m_cfg; }

private:
Config m_cfg;

std::vector<std::int64_t> buildFromMeasurements(
const SimSpacePointContainer& spacepoints,
const SimParticleContainer& particles,
const IndexMultimap<ActsFatras::Barcode>& measPartMap) const;

std::vector<std::int64_t> buildFromSimhits(
const SimSpacePointContainer& spacepoints,
const IndexMultimap<Index>& measHitMap, const SimHitContainer& simhits,
const SimParticleContainer& particles) const;

ReadDataHandle<SimSpacePointContainer> m_inputSpacePoints{this,
"InputSpacePoints"};
ReadDataHandle<SimParticleContainer> m_inputParticles{this, "InputParticles"};
ReadDataHandle<IndexMultimap<ActsFatras::Barcode>> m_inputMeasParticlesMap{
this, "InputMeasParticlesMap"};
ReadDataHandle<SimHitContainer> m_inputSimhits{this, "InputSimhits"};
ReadDataHandle<IndexMultimap<Index>> m_inputMeasSimhitMap{
this, "InputMeasSimhitMap"};

WriteDataHandle<Graph> m_outputGraph{this, "OutputGraph"};
};
} // namespace ActsExamples
Original file line number Diff line number Diff line change
Expand Up @@ -25,97 +25,16 @@ using namespace Acts::UnitLiterals;

namespace {

class ExamplesEdmHook : public Acts::ExaTrkXHook {
double m_targetPT = 0.5_GeV;
std::size_t m_targetSize = 3;

std::unique_ptr<const Acts::Logger> m_logger;
std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_truthGraphHook;
std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_targetGraphHook;
std::unique_ptr<Acts::TorchGraphStoreHook> m_graphStoreHook;

const Acts::Logger& logger() const { return *m_logger; }

struct HitInfo {
std::size_t spacePointIndex;
std::int32_t hitIndex;
};

public:
ExamplesEdmHook(const SimSpacePointContainer& spacepoints,
const IndexMultimap<Index>& measHitMap,
const SimHitContainer& simhits,
const SimParticleContainer& particles,
std::size_t targetMinHits, double targetMinPT,
const Acts::Logger& logger)
: m_targetPT(targetMinPT),
m_targetSize(targetMinHits),
m_logger(logger.clone("MetricsHook")) {
// Associate tracks to graph, collect momentum
std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>> tracks;

for (auto i = 0ul; i < spacepoints.size(); ++i) {
const auto measId = spacepoints[i]
.sourceLinks()[0]
.template get<IndexSourceLink>()
.index();

auto [a, b] = measHitMap.equal_range(measId);
for (auto it = a; it != b; ++it) {
const auto& hit = *simhits.nth(it->second);

tracks[hit.particleId()].push_back({i, hit.index()});
}
}

// Collect edges for truth graph and target graph
std::vector<std::int64_t> truthGraph;
std::vector<std::int64_t> targetGraph;

for (auto& [pid, track] : tracks) {
// Sort by hit index, so the edges are connected correctly
std::sort(track.begin(), track.end(), [](const auto& a, const auto& b) {
return a.hitIndex < b.hitIndex;
});

auto found = particles.find(pid);
if (found == particles.end()) {
ACTS_WARNING("Did not find " << pid << ", skip track");
continue;
}

for (auto i = 0ul; i < track.size() - 1; ++i) {
truthGraph.push_back(track[i].spacePointIndex);
truthGraph.push_back(track[i + 1].spacePointIndex);

if (found->transverseMomentum() > m_targetPT &&
track.size() >= m_targetSize) {
targetGraph.push_back(track[i].spacePointIndex);
targetGraph.push_back(track[i + 1].spacePointIndex);
}
}
}
struct LoopHook : public Acts::ExaTrkXHook {
std::vector<Acts::ExaTrkXHook*> hooks;

m_truthGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
truthGraph, logger.clone());
m_targetGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
targetGraph, logger.clone());
m_graphStoreHook = std::make_unique<Acts::TorchGraphStoreHook>();
}

~ExamplesEdmHook() {}

auto storedGraph() const { return m_graphStoreHook->storedGraph(); }
~LoopHook() {}

void operator()(const std::any& nodes, const std::any& edges,
const std::any& weights) const override {
ACTS_INFO("Metrics for total graph:");
(*m_truthGraphHook)(nodes, edges, weights);
ACTS_INFO("Metrics for target graph (pT > "
<< m_targetPT / Acts::UnitConstants::GeV
<< " GeV, nHits >= " << m_targetSize << "):");
(*m_targetGraphHook)(nodes, edges, weights);
(*m_graphStoreHook)(nodes, edges, weights);
for (auto hook : hooks) {
(*hook)(nodes, edges, weights);
}
}
};

Expand Down Expand Up @@ -164,10 +83,7 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(
m_inputClusters.maybeInitialize(m_cfg.inputClusters);
m_outputProtoTracks.initialize(m_cfg.outputProtoTracks);

m_inputSimHits.maybeInitialize(m_cfg.inputSimHits);
m_inputParticles.maybeInitialize(m_cfg.inputParticles);
m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap);

m_inputTruthGraph.maybeInitialize(m_cfg.inputTruthGraph);
m_outputGraph.maybeInitialize(m_cfg.outputGraph);

// reserve space for timing
Expand Down Expand Up @@ -200,17 +116,25 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(

ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
const ActsExamples::AlgorithmContext& ctx) const {
// Read input data
auto spacepoints = m_inputSpacePoints(ctx);
// Setup hooks
LoopHook hook;

std::unique_ptr<Acts::TorchTruthGraphMetricsHook> truthGraphHook;
if (m_inputTruthGraph.isInitialized()) {
truthGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
m_inputTruthGraph(ctx).edges, this->logger().clone());
hook.hooks.push_back(&*truthGraphHook);
}

auto hook = std::make_unique<Acts::ExaTrkXHook>();
if (m_inputSimHits.isInitialized() && m_inputMeasurementMap.isInitialized()) {
hook = std::make_unique<ExamplesEdmHook>(
spacepoints, m_inputMeasurementMap(ctx), m_inputSimHits(ctx),
m_inputParticles(ctx), m_cfg.targetMinHits, m_cfg.targetMinPT,
logger());
std::unique_ptr<Acts::TorchGraphStoreHook> graphStoreHook;
if (m_outputGraph.isInitialized()) {
graphStoreHook = std::make_unique<Acts::TorchGraphStoreHook>();
hook.hooks.push_back(&*graphStoreHook);
}

// Read input data
auto spacepoints = m_inputSpacePoints(ctx);

std::optional<ClusterContainer> clusters;
if (m_inputClusters.isInitialized()) {
clusters = m_inputClusters(ctx);
Expand Down Expand Up @@ -288,7 +212,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
std::lock_guard<std::mutex> lock(m_mutex);

Acts::ExaTrkXTiming timing;
auto res = m_pipeline.run(features, spacepointIDs, *hook, &timing);
auto res = m_pipeline.run(features, spacepointIDs, hook, &timing);

m_timing.graphBuildingTime(timing.graphBuildingTime.count());

Expand Down Expand Up @@ -329,13 +253,12 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
ACTS_INFO("Created " << protoTracks.size() << " proto tracks");
m_outputProtoTracks(ctx, std::move(protoTracks));

if (auto dhook = dynamic_cast<ExamplesEdmHook*>(&*hook);
dhook && m_outputGraph.isInitialized()) {
auto graph = dhook->storedGraph();
if (m_outputGraph.isInitialized()) {
auto graph = graphStoreHook->storedGraph();
std::transform(
graph.first.begin(), graph.first.end(), graph.first.begin(),
[&](const auto& a) -> std::int64_t { return spacepointIDs.at(a); });
m_outputGraph(ctx, std::move(graph));
m_outputGraph(ctx, {graph.first, graph.second});
}

return ActsExamples::ProcessCode::SUCCESS;
Expand Down
Loading
Loading