Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Update GNN plugin #3876

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "Acts/Plugins/ExaTrkX/Stages.hpp"
#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/GeometryContainers.hpp"
#include "ActsExamples/EventData/Graph.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
#include "ActsExamples/EventData/SimHit.hpp"
Expand All @@ -34,24 +35,60 @@ namespace ActsExamples {
class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
public:
enum class NodeFeature {
// SP features
eR,
ePhi,
eX,
eY,
eZ,
eEta,
// Single cluster features
eCellCount,
eCellSum,
eClusterX,
eClusterY,
eChargeSum,
eClusterLoc0,
eClusterLoc1,
// Cluster 1 features
eCluster1X,
eCluster1Y,
eCluster1Z,
eCluster1R,
eCluster2R,
eCluster1Phi,
eCluster2Phi,
eCluster1Z,
eCluster2Z,
eCluster1Eta,
eCellCount1,
eChargeSum1,
eLocEta1,
eLocPhi1,
eLocDir01,
eLocDir11,
eLocDir21,
eLengthDir01,
eLengthDir11,
eLengthDir21,
eGlobEta1,
eGlobPhi1,
eEtaAngle1,
ePhiAngle1,
// Cluster 2 features
eCluster2X,
eCluster2Y,
eCluster2Z,
eCluster2R,
eCluster2Phi,
eCluster2Eta,
eCellCount2,
eChargeSum2,
eLocEta2,
eLocPhi2,
eLocDir02,
eLocDir12,
eLocDir22,
eLengthDir02,
eLengthDir12,
eLengthDir22,
eGlobEta2,
eGlobPhi2,
eEtaAngle2,
ePhiAngle2,
};

struct Config {
Expand All @@ -63,14 +100,16 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
std::string inputTruthGraph;
/// Output prototracks
std::string outputProtoTracks;

/// Output graph (optional)
std::string outputGraph;

/// Graph constructor
std::shared_ptr<Acts::GraphConstructionBase> graphConstructor;

/// List of edge classifiers
std::vector<std::shared_ptr<Acts::EdgeClassificationBase>> edgeClassifiers;

/// The track builder
std::shared_ptr<Acts::TrackBuildingBase> trackBuilder;

/// Node features
Expand All @@ -81,7 +120,10 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
std::vector<float> featureScales = {1.f, 1.f, 1.f};

/// Remove track candidates with 2 or less hits
bool filterShortTracks = false;
std::size_t minMeasurementsPerTrack = 3;

/// Optionally remap the geometry Ids that are put into the chain
std::shared_ptr<GeometryIdMapActsAthena> geometryIdMap;
};

/// Constructor of the track finding algorithm
Expand Down Expand Up @@ -111,13 +153,17 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
mutable std::mutex m_mutex;

using Accumulator = boost::accumulators::accumulator_set<
float, boost::accumulators::features<boost::accumulators::tag::mean,
boost::accumulators::tag::variance>>;
float,
boost::accumulators::features<
boost::accumulators::tag::mean, boost::accumulators::tag::variance,
boost::accumulators::tag::max, boost::accumulators::tag::min>>;

mutable struct {
Accumulator preprocessingTime;
benjaminhuth marked this conversation as resolved.
Show resolved Hide resolved
Accumulator graphBuildingTime;
std::vector<Accumulator> classifierTimes;
Accumulator trackBuildingTime;
Accumulator postprocessingTime;
benjaminhuth marked this conversation as resolved.
Show resolved Hide resolved
} m_timing;

ReadDataHandle<SimSpacePointContainer> m_inputSpacePoints{this,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
#include "ActsExamples/Framework/WhiteBoard.hpp"

#include <algorithm>
#include <chrono>
#include <numeric>

#include "createFeatures.hpp"

using namespace ActsExamples;
using namespace Acts::UnitLiterals;

Expand Down Expand Up @@ -87,8 +90,9 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(

// Check if we want cluster features but do not have them
const static std::array clFeatures = {
NodeFeature::eClusterX, NodeFeature::eClusterY, NodeFeature::eCellCount,
NodeFeature::eCellSum, NodeFeature::eCluster1R, NodeFeature::eCluster2R};
NodeFeature::eClusterLoc0, NodeFeature::eClusterLoc0,
NodeFeature::eCellCount, NodeFeature::eChargeSum,
NodeFeature::eCluster1R, NodeFeature::eCluster2R};

auto wantClFeatures = std::ranges::any_of(
m_cfg.nodeFeatures,
Expand All @@ -108,6 +112,10 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(

ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
const ActsExamples::AlgorithmContext& ctx) const {
using Clock = std::chrono::high_resolution_clock;
using Duration = std::chrono::duration<double, std::milli>;
auto t0 = Clock::now();

// Setup hooks
LoopHook hook;

Expand Down Expand Up @@ -139,10 +147,11 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
ACTS_DEBUG("Received " << numSpacepoints << " spacepoints");
ACTS_DEBUG("Construct " << numFeatures << " node features");

std::vector<float> features(numSpacepoints * numFeatures);
std::vector<int> spacepointIDs;
std::vector<std::uint64_t> moduleIds;

spacepointIDs.reserve(spacepoints.size());
moduleIds.reserve(spacepoints.size());

for (auto isp = 0ul; isp < numSpacepoints; ++isp) {
const auto& sp = spacepoints[isp];
Expand All @@ -157,54 +166,25 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
// to the pipeline
spacepointIDs.push_back(isp);

// This should be fine, because check in constructor
Cluster* cl1 = clusters ? &clusters->at(sl1.index()) : nullptr;
Cluster* cl2 = cl1;

if (sp.sourceLinks().size() == 2) {
const auto& sl2 = sp.sourceLinks()[1].template get<IndexSourceLink>();
cl2 = clusters ? &clusters->at(sl2.index()) : nullptr;
if (m_cfg.geometryIdMap != nullptr) {
moduleIds.push_back(m_cfg.geometryIdMap->right.at(sl1.geometryId()));
} else {
moduleIds.push_back(sl1.geometryId().value());
}
}

// I would prefer to use a std::span or boost::span here once available
float* f = features.data() + isp * numFeatures;

using NF = NodeFeature;

for (auto ift = 0ul; ift < numFeatures; ++ift) {
// clang-format off
switch(m_cfg.nodeFeatures[ift]) {
break; case NF::eR: f[ift] = std::hypot(sp.x(), sp.y());
break; case NF::ePhi: f[ift] = std::atan2(sp.y(), sp.x());
break; case NF::eZ: f[ift] = sp.z();
break; case NF::eX: f[ift] = sp.x();
break; case NF::eY: f[ift] = sp.y();
break; case NF::eEta: f[ift] = Acts::VectorHelpers::eta(Acts::Vector3{sp.x(), sp.y(), sp.z()});
break; case NF::eClusterX: f[ift] = cl1->sizeLoc0;
break; case NF::eClusterY: f[ift] = cl1->sizeLoc1;
break; case NF::eCellSum: f[ift] = cl1->sumActivations();
break; case NF::eCellCount: f[ift] = cl1->channels.size();
break; case NF::eCluster1R: f[ift] = std::hypot(cl1->globalPosition[Acts::ePos0], cl1->globalPosition[Acts::ePos1]);
break; case NF::eCluster2R: f[ift] = std::hypot(cl2->globalPosition[Acts::ePos0], cl2->globalPosition[Acts::ePos1]);
break; case NF::eCluster1Phi: f[ift] = std::atan2(cl1->globalPosition[Acts::ePos1], cl1->globalPosition[Acts::ePos0]);
break; case NF::eCluster2Phi: f[ift] = std::atan2(cl2->globalPosition[Acts::ePos1], cl2->globalPosition[Acts::ePos0]);
break; case NF::eCluster1Z: f[ift] = cl1->globalPosition[Acts::ePos2];
break; case NF::eCluster2Z: f[ift] = cl2->globalPosition[Acts::ePos2];
break; case NF::eCluster1Eta: f[ift] = Acts::VectorHelpers::eta(Acts::Vector3{cl1->globalPosition[Acts::ePos0], cl1->globalPosition[Acts::ePos1], cl1->globalPosition[Acts::ePos2]});
break; case NF::eCluster2Eta: f[ift] = Acts::VectorHelpers::eta(Acts::Vector3{cl2->globalPosition[Acts::ePos0], cl2->globalPosition[Acts::ePos1], cl2->globalPosition[Acts::ePos2]});
}
// clang-format on
auto features = createFeatures(spacepoints, clusters, m_cfg.nodeFeatures,
m_cfg.featureScales);

f[ift] /= m_cfg.featureScales[ift];
}
}
auto t1 = Clock::now();

// Run the pipeline
const auto trackCandidates = [&]() {
std::lock_guard<std::mutex> lock(m_mutex);

Acts::ExaTrkXTiming timing;
auto res = m_pipeline.run(features, spacepointIDs, hook, &timing);
auto res =
m_pipeline.run(features, moduleIds, spacepointIDs, hook, &timing);

m_timing.graphBuildingTime(timing.graphBuildingTime.count());

Expand All @@ -219,6 +199,8 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
return res;
}();

auto t2 = Clock::now();

ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size()
<< " candidates");

Expand All @@ -228,20 +210,28 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(

int nShortTracks = 0;

for (auto& x : trackCandidates) {
if (m_cfg.filterShortTracks && x.size() < 3) {
/// TODO the whole conversion back to meas idxs should be pulled out of the
/// track trackBuilder
for (auto& candidate : trackCandidates) {
ProtoTrack onetrack;
onetrack.reserve(candidate.size());

for (auto i : candidate) {
for (const auto& sl : spacepoints[i].sourceLinks()) {
onetrack.push_back(sl.template get<IndexSourceLink>().index());
}
}

if (onetrack.size() < m_cfg.minMeasurementsPerTrack) {
nShortTracks++;
continue;
}

ProtoTrack onetrack;
onetrack.reserve(x.size());

std::copy(x.begin(), x.end(), std::back_inserter(onetrack));
protoTracks.push_back(std::move(onetrack));
}

ACTS_INFO("Removed " << nShortTracks << " with less then 3 hits");
ACTS_INFO("Removed " << nShortTracks << " with less then "
<< m_cfg.minMeasurementsPerTrack << " hits");
ACTS_INFO("Created " << protoTracks.size() << " proto tracks");
m_outputProtoTracks(ctx, std::move(protoTracks));

Expand All @@ -253,27 +243,33 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
m_outputGraph(ctx, {graph.first, graph.second});
}

auto t3 = Clock::now();
m_timing.preprocessingTime(Duration(t1 - t0).count());
m_timing.postprocessingTime(Duration(t3 - t2).count());

return ActsExamples::ProcessCode::SUCCESS;
}

ActsExamples::ProcessCode TrackFindingAlgorithmExaTrkX::finalize() {
namespace ba = boost::accumulators;

auto print = [](const auto& t) {
std::stringstream ss;
ss << ba::mean(t) << " +- " << std::sqrt(ba::variance(t)) << " ";
ss << "[" << ba::min(t) << ", " << ba::max(t) << "]";
return ss.str();
};

ACTS_INFO("Exa.TrkX timing info");
{
const auto& t = m_timing.graphBuildingTime;
ACTS_INFO("- graph building: " << ba::mean(t) << " +- "
<< std::sqrt(ba::variance(t)));
}
ACTS_INFO("- preprocessing: " << print(m_timing.preprocessingTime));
ACTS_INFO("- graph building: " << print(m_timing.graphBuildingTime));
// clang-format off
for (const auto& t : m_timing.classifierTimes) {
ACTS_INFO("- classifier: " << ba::mean(t) << " +- "
<< std::sqrt(ba::variance(t)));
}
{
const auto& t = m_timing.trackBuildingTime;
ACTS_INFO("- track building: " << ba::mean(t) << " +- "
<< std::sqrt(ba::variance(t)));
ACTS_INFO("- classifier: " << print(t));
}
// clang-format on
ACTS_INFO("- track building: " << print(m_timing.trackBuildingTime));
ACTS_INFO("- postprocessing: " << print(m_timing.postprocessingTime));

return {};
}
Loading
Loading