Skip to content

Commit

Permalink
feat: Update GNN plugin (#3876)
Browse files Browse the repository at this point in the history
This updates the GNN plugin from my dev branch. Contains:
* Enhance interfaces of pipeline stages to handle edge features and module-based graph construction
* Add a lot more hit-features
* Refactor feature building into separate function
* Change feature selection in for pipeline stages
* More detailed time measurements
* Start to refactor ONNX edge classification
  • Loading branch information
benjaminhuth authored Nov 26, 2024
1 parent 4656a8d commit a0aed8c
Show file tree
Hide file tree
Showing 18 changed files with 659 additions and 256 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "Acts/Plugins/ExaTrkX/Stages.hpp"
#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/GeometryContainers.hpp"
#include "ActsExamples/EventData/Graph.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
#include "ActsExamples/EventData/SimHit.hpp"
Expand All @@ -34,24 +35,60 @@ namespace ActsExamples {
class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
public:
enum class NodeFeature {
// SP features
eR,
ePhi,
eX,
eY,
eZ,
eEta,
// Single cluster features
eCellCount,
eCellSum,
eClusterX,
eClusterY,
eChargeSum,
eClusterLoc0,
eClusterLoc1,
// Cluster 1 features
eCluster1X,
eCluster1Y,
eCluster1Z,
eCluster1R,
eCluster2R,
eCluster1Phi,
eCluster2Phi,
eCluster1Z,
eCluster2Z,
eCluster1Eta,
eCellCount1,
eChargeSum1,
eLocEta1,
eLocPhi1,
eLocDir01,
eLocDir11,
eLocDir21,
eLengthDir01,
eLengthDir11,
eLengthDir21,
eGlobEta1,
eGlobPhi1,
eEtaAngle1,
ePhiAngle1,
// Cluster 2 features
eCluster2X,
eCluster2Y,
eCluster2Z,
eCluster2R,
eCluster2Phi,
eCluster2Eta,
eCellCount2,
eChargeSum2,
eLocEta2,
eLocPhi2,
eLocDir02,
eLocDir12,
eLocDir22,
eLengthDir02,
eLengthDir12,
eLengthDir22,
eGlobEta2,
eGlobPhi2,
eEtaAngle2,
ePhiAngle2,
};

struct Config {
Expand All @@ -63,14 +100,16 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
std::string inputTruthGraph;
/// Output prototracks
std::string outputProtoTracks;

/// Output graph (optional)
std::string outputGraph;

/// Graph constructor
std::shared_ptr<Acts::GraphConstructionBase> graphConstructor;

/// List of edge classifiers
std::vector<std::shared_ptr<Acts::EdgeClassificationBase>> edgeClassifiers;

/// The track builder
std::shared_ptr<Acts::TrackBuildingBase> trackBuilder;

/// Node features
Expand All @@ -81,7 +120,10 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
std::vector<float> featureScales = {1.f, 1.f, 1.f};

/// Remove track candidates with 2 or less hits
bool filterShortTracks = false;
std::size_t minMeasurementsPerTrack = 3;

/// Optionally remap the geometry Ids that are put into the chain
std::shared_ptr<GeometryIdMapActsAthena> geometryIdMap;
};

/// Constructor of the track finding algorithm
Expand Down Expand Up @@ -111,13 +153,17 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
mutable std::mutex m_mutex;

using Accumulator = boost::accumulators::accumulator_set<
float, boost::accumulators::features<boost::accumulators::tag::mean,
boost::accumulators::tag::variance>>;
float,
boost::accumulators::features<
boost::accumulators::tag::mean, boost::accumulators::tag::variance,
boost::accumulators::tag::max, boost::accumulators::tag::min>>;

mutable struct {
Accumulator preprocessingTime;
Accumulator graphBuildingTime;
std::vector<Accumulator> classifierTimes;
Accumulator trackBuildingTime;
Accumulator postprocessingTime;
} m_timing;

ReadDataHandle<SimSpacePointContainer> m_inputSpacePoints{this,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
#include "ActsExamples/Framework/WhiteBoard.hpp"

#include <algorithm>
#include <chrono>
#include <numeric>

#include "createFeatures.hpp"

using namespace ActsExamples;
using namespace Acts::UnitLiterals;

Expand Down Expand Up @@ -87,8 +90,9 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(

// Check if we want cluster features but do not have them
const static std::array clFeatures = {
NodeFeature::eClusterX, NodeFeature::eClusterY, NodeFeature::eCellCount,
NodeFeature::eCellSum, NodeFeature::eCluster1R, NodeFeature::eCluster2R};
NodeFeature::eClusterLoc0, NodeFeature::eClusterLoc0,
NodeFeature::eCellCount, NodeFeature::eChargeSum,
NodeFeature::eCluster1R, NodeFeature::eCluster2R};

auto wantClFeatures = std::ranges::any_of(
m_cfg.nodeFeatures,
Expand All @@ -108,6 +112,10 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(

ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
const ActsExamples::AlgorithmContext& ctx) const {
using Clock = std::chrono::high_resolution_clock;
using Duration = std::chrono::duration<double, std::milli>;
auto t0 = Clock::now();

// Setup hooks
LoopHook hook;

Expand Down Expand Up @@ -139,10 +147,11 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
ACTS_DEBUG("Received " << numSpacepoints << " spacepoints");
ACTS_DEBUG("Construct " << numFeatures << " node features");

std::vector<float> features(numSpacepoints * numFeatures);
std::vector<int> spacepointIDs;
std::vector<std::uint64_t> moduleIds;

spacepointIDs.reserve(spacepoints.size());
moduleIds.reserve(spacepoints.size());

for (auto isp = 0ul; isp < numSpacepoints; ++isp) {
const auto& sp = spacepoints[isp];
Expand All @@ -157,54 +166,25 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
// to the pipeline
spacepointIDs.push_back(isp);

// This should be fine, because check in constructor
Cluster* cl1 = clusters ? &clusters->at(sl1.index()) : nullptr;
Cluster* cl2 = cl1;

if (sp.sourceLinks().size() == 2) {
const auto& sl2 = sp.sourceLinks()[1].template get<IndexSourceLink>();
cl2 = clusters ? &clusters->at(sl2.index()) : nullptr;
if (m_cfg.geometryIdMap != nullptr) {
moduleIds.push_back(m_cfg.geometryIdMap->right.at(sl1.geometryId()));
} else {
moduleIds.push_back(sl1.geometryId().value());
}
}

// I would prefer to use a std::span or boost::span here once available
float* f = features.data() + isp * numFeatures;

using NF = NodeFeature;

for (auto ift = 0ul; ift < numFeatures; ++ift) {
// clang-format off
switch(m_cfg.nodeFeatures[ift]) {
break; case NF::eR: f[ift] = std::hypot(sp.x(), sp.y());
break; case NF::ePhi: f[ift] = std::atan2(sp.y(), sp.x());
break; case NF::eZ: f[ift] = sp.z();
break; case NF::eX: f[ift] = sp.x();
break; case NF::eY: f[ift] = sp.y();
break; case NF::eEta: f[ift] = Acts::VectorHelpers::eta(Acts::Vector3{sp.x(), sp.y(), sp.z()});
break; case NF::eClusterX: f[ift] = cl1->sizeLoc0;
break; case NF::eClusterY: f[ift] = cl1->sizeLoc1;
break; case NF::eCellSum: f[ift] = cl1->sumActivations();
break; case NF::eCellCount: f[ift] = cl1->channels.size();
break; case NF::eCluster1R: f[ift] = std::hypot(cl1->globalPosition[Acts::ePos0], cl1->globalPosition[Acts::ePos1]);
break; case NF::eCluster2R: f[ift] = std::hypot(cl2->globalPosition[Acts::ePos0], cl2->globalPosition[Acts::ePos1]);
break; case NF::eCluster1Phi: f[ift] = std::atan2(cl1->globalPosition[Acts::ePos1], cl1->globalPosition[Acts::ePos0]);
break; case NF::eCluster2Phi: f[ift] = std::atan2(cl2->globalPosition[Acts::ePos1], cl2->globalPosition[Acts::ePos0]);
break; case NF::eCluster1Z: f[ift] = cl1->globalPosition[Acts::ePos2];
break; case NF::eCluster2Z: f[ift] = cl2->globalPosition[Acts::ePos2];
break; case NF::eCluster1Eta: f[ift] = Acts::VectorHelpers::eta(Acts::Vector3{cl1->globalPosition[Acts::ePos0], cl1->globalPosition[Acts::ePos1], cl1->globalPosition[Acts::ePos2]});
break; case NF::eCluster2Eta: f[ift] = Acts::VectorHelpers::eta(Acts::Vector3{cl2->globalPosition[Acts::ePos0], cl2->globalPosition[Acts::ePos1], cl2->globalPosition[Acts::ePos2]});
}
// clang-format on
auto features = createFeatures(spacepoints, clusters, m_cfg.nodeFeatures,
m_cfg.featureScales);

f[ift] /= m_cfg.featureScales[ift];
}
}
auto t1 = Clock::now();

// Run the pipeline
const auto trackCandidates = [&]() {
std::lock_guard<std::mutex> lock(m_mutex);

Acts::ExaTrkXTiming timing;
auto res = m_pipeline.run(features, spacepointIDs, hook, &timing);
auto res =
m_pipeline.run(features, moduleIds, spacepointIDs, hook, &timing);

m_timing.graphBuildingTime(timing.graphBuildingTime.count());

Expand All @@ -219,6 +199,8 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
return res;
}();

auto t2 = Clock::now();

ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size()
<< " candidates");

Expand All @@ -228,20 +210,28 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(

int nShortTracks = 0;

for (auto& x : trackCandidates) {
if (m_cfg.filterShortTracks && x.size() < 3) {
/// TODO the whole conversion back to meas idxs should be pulled out of the
/// track trackBuilder
for (auto& candidate : trackCandidates) {
ProtoTrack onetrack;
onetrack.reserve(candidate.size());

for (auto i : candidate) {
for (const auto& sl : spacepoints[i].sourceLinks()) {
onetrack.push_back(sl.template get<IndexSourceLink>().index());
}
}

if (onetrack.size() < m_cfg.minMeasurementsPerTrack) {
nShortTracks++;
continue;
}

ProtoTrack onetrack;
onetrack.reserve(x.size());

std::copy(x.begin(), x.end(), std::back_inserter(onetrack));
protoTracks.push_back(std::move(onetrack));
}

ACTS_INFO("Removed " << nShortTracks << " with less then 3 hits");
ACTS_INFO("Removed " << nShortTracks << " with less then "
<< m_cfg.minMeasurementsPerTrack << " hits");
ACTS_INFO("Created " << protoTracks.size() << " proto tracks");
m_outputProtoTracks(ctx, std::move(protoTracks));

Expand All @@ -253,27 +243,33 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
m_outputGraph(ctx, {graph.first, graph.second});
}

auto t3 = Clock::now();
m_timing.preprocessingTime(Duration(t1 - t0).count());
m_timing.postprocessingTime(Duration(t3 - t2).count());

return ActsExamples::ProcessCode::SUCCESS;
}

ActsExamples::ProcessCode TrackFindingAlgorithmExaTrkX::finalize() {
namespace ba = boost::accumulators;

auto print = [](const auto& t) {
std::stringstream ss;
ss << ba::mean(t) << " +- " << std::sqrt(ba::variance(t)) << " ";
ss << "[" << ba::min(t) << ", " << ba::max(t) << "]";
return ss.str();
};

ACTS_INFO("Exa.TrkX timing info");
{
const auto& t = m_timing.graphBuildingTime;
ACTS_INFO("- graph building: " << ba::mean(t) << " +- "
<< std::sqrt(ba::variance(t)));
}
ACTS_INFO("- preprocessing: " << print(m_timing.preprocessingTime));
ACTS_INFO("- graph building: " << print(m_timing.graphBuildingTime));
// clang-format off
for (const auto& t : m_timing.classifierTimes) {
ACTS_INFO("- classifier: " << ba::mean(t) << " +- "
<< std::sqrt(ba::variance(t)));
}
{
const auto& t = m_timing.trackBuildingTime;
ACTS_INFO("- track building: " << ba::mean(t) << " +- "
<< std::sqrt(ba::variance(t)));
ACTS_INFO("- classifier: " << print(t));
}
// clang-format on
ACTS_INFO("- track building: " << print(m_timing.trackBuildingTime));
ACTS_INFO("- postprocessing: " << print(m_timing.postprocessingTime));

return {};
}
Loading

0 comments on commit a0aed8c

Please sign in to comment.