From a0aed8cacd5709106ec5d3ba3a76d915916ce46c Mon Sep 17 00:00:00 2001 From: Benjamin Huth <37871400+benjaminhuth@users.noreply.github.com> Date: Tue, 26 Nov 2024 01:15:21 +0100 Subject: [PATCH] feat: Update GNN plugin (#3876) This updates the GNN plugin from my dev branch. Contains: * Enhance interfaces of pipeline stages to handle edge features and module-based graph construction * Add a lot more hit-features * Refactor feature building into separate function * Change feature selection for pipeline stages * More detailed time measurements * Start to refactor ONNX edge classification --- .../TrackFindingAlgorithmExaTrkX.hpp | 68 +++++-- .../src/TrackFindingAlgorithmExaTrkX.cpp | 118 ++++++------ .../src/createFeatures.cpp | 108 +++++++++++ .../src/createFeatures.hpp | 21 ++ Examples/Python/src/ExaTrkXTrackFinding.cpp | 94 +++++---- .../Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp | 1 + .../Plugins/ExaTrkX/OnnxEdgeClassifier.hpp | 9 +- .../Plugins/ExaTrkX/OnnxMetricLearning.hpp | 3 +- .../include/Acts/Plugins/ExaTrkX/Stages.hpp | 31 +-- .../Plugins/ExaTrkX/TorchEdgeClassifier.hpp | 7 +- .../Plugins/ExaTrkX/TorchMetricLearning.hpp | 8 +- .../Acts/Plugins/ExaTrkX/detail/Utils.hpp | 51 +++++ Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp | 82 ++++---- Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp | 179 ++++++++++++------ Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp | 86 ++++++--- Plugins/ExaTrkX/src/TorchMetricLearning.cpp | 21 +- .../src/TorchTruthGraphMetricsHook.cpp | 26 ++- Plugins/ExaTrkX/src/printCudaMemInfo.hpp | 2 +- 18 files changed, 659 insertions(+), 256 deletions(-) create mode 100644 Examples/Algorithms/TrackFindingExaTrkX/src/createFeatures.cpp create mode 100644 Examples/Algorithms/TrackFindingExaTrkX/src/createFeatures.hpp create mode 100644 Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/detail/Utils.hpp diff --git a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp 
b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp index f85a000aac7..79ee477f662 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp @@ -13,6 +13,7 @@ #include "Acts/Plugins/ExaTrkX/Stages.hpp" #include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" #include "ActsExamples/EventData/Cluster.hpp" +#include "ActsExamples/EventData/GeometryContainers.hpp" #include "ActsExamples/EventData/Graph.hpp" #include "ActsExamples/EventData/ProtoTrack.hpp" #include "ActsExamples/EventData/SimHit.hpp" @@ -34,24 +35,60 @@ namespace ActsExamples { class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { public: enum class NodeFeature { + // SP features eR, ePhi, eX, eY, eZ, eEta, + // Single cluster features eCellCount, - eCellSum, - eClusterX, - eClusterY, + eChargeSum, + eClusterLoc0, + eClusterLoc1, + // Cluster 1 features + eCluster1X, + eCluster1Y, + eCluster1Z, eCluster1R, - eCluster2R, eCluster1Phi, - eCluster2Phi, - eCluster1Z, - eCluster2Z, eCluster1Eta, + eCellCount1, + eChargeSum1, + eLocEta1, + eLocPhi1, + eLocDir01, + eLocDir11, + eLocDir21, + eLengthDir01, + eLengthDir11, + eLengthDir21, + eGlobEta1, + eGlobPhi1, + eEtaAngle1, + ePhiAngle1, + // Cluster 2 features + eCluster2X, + eCluster2Y, + eCluster2Z, + eCluster2R, + eCluster2Phi, eCluster2Eta, + eCellCount2, + eChargeSum2, + eLocEta2, + eLocPhi2, + eLocDir02, + eLocDir12, + eLocDir22, + eLengthDir02, + eLengthDir12, + eLengthDir22, + eGlobEta2, + eGlobPhi2, + eEtaAngle2, + ePhiAngle2, }; struct Config { @@ -63,14 +100,16 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { std::string inputTruthGraph; /// Output prototracks std::string outputProtoTracks; - /// Output graph (optional) std::string outputGraph; + /// Graph constructor 
std::shared_ptr graphConstructor; + /// List of edge classifiers std::vector> edgeClassifiers; + /// The track builder std::shared_ptr trackBuilder; /// Node features @@ -81,7 +120,10 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { std::vector featureScales = {1.f, 1.f, 1.f}; /// Remove track candidates with 2 or less hits - bool filterShortTracks = false; + std::size_t minMeasurementsPerTrack = 3; + + /// Optionally remap the geometry Ids that are put into the chain + std::shared_ptr geometryIdMap; }; /// Constructor of the track finding algorithm @@ -111,13 +153,17 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { mutable std::mutex m_mutex; using Accumulator = boost::accumulators::accumulator_set< - float, boost::accumulators::features>; + float, + boost::accumulators::features< + boost::accumulators::tag::mean, boost::accumulators::tag::variance, + boost::accumulators::tag::max, boost::accumulators::tag::min>>; mutable struct { + Accumulator preprocessingTime; Accumulator graphBuildingTime; std::vector classifierTimes; Accumulator trackBuildingTime; + Accumulator postprocessingTime; } m_timing; ReadDataHandle m_inputSpacePoints{this, diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp index df4b8998540..858b90359da 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp @@ -20,8 +20,11 @@ #include "ActsExamples/Framework/WhiteBoard.hpp" #include +#include #include +#include "createFeatures.hpp" + using namespace ActsExamples; using namespace Acts::UnitLiterals; @@ -87,8 +90,9 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX( // Check if we want cluster features but do not have them const static std::array clFeatures = { - NodeFeature::eClusterX, NodeFeature::eClusterY, 
NodeFeature::eCellCount, - NodeFeature::eCellSum, NodeFeature::eCluster1R, NodeFeature::eCluster2R}; + NodeFeature::eClusterLoc0, NodeFeature::eClusterLoc0, + NodeFeature::eCellCount, NodeFeature::eChargeSum, + NodeFeature::eCluster1R, NodeFeature::eCluster2R}; auto wantClFeatures = std::ranges::any_of( m_cfg.nodeFeatures, @@ -108,6 +112,10 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX( ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( const ActsExamples::AlgorithmContext& ctx) const { + using Clock = std::chrono::high_resolution_clock; + using Duration = std::chrono::duration; + auto t0 = Clock::now(); + // Setup hooks LoopHook hook; @@ -139,10 +147,11 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( ACTS_DEBUG("Received " << numSpacepoints << " spacepoints"); ACTS_DEBUG("Construct " << numFeatures << " node features"); - std::vector features(numSpacepoints * numFeatures); std::vector spacepointIDs; + std::vector moduleIds; spacepointIDs.reserve(spacepoints.size()); + moduleIds.reserve(spacepoints.size()); for (auto isp = 0ul; isp < numSpacepoints; ++isp) { const auto& sp = spacepoints[isp]; @@ -157,54 +166,25 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( // to the pipeline spacepointIDs.push_back(isp); - // This should be fine, because check in constructor - Cluster* cl1 = clusters ? &clusters->at(sl1.index()) : nullptr; - Cluster* cl2 = cl1; - - if (sp.sourceLinks().size() == 2) { - const auto& sl2 = sp.sourceLinks()[1].template get(); - cl2 = clusters ? 
&clusters->at(sl2.index()) : nullptr; + if (m_cfg.geometryIdMap != nullptr) { + moduleIds.push_back(m_cfg.geometryIdMap->right.at(sl1.geometryId())); + } else { + moduleIds.push_back(sl1.geometryId().value()); } + } - // I would prefer to use a std::span or boost::span here once available - float* f = features.data() + isp * numFeatures; - - using NF = NodeFeature; - - for (auto ift = 0ul; ift < numFeatures; ++ift) { - // clang-format off - switch(m_cfg.nodeFeatures[ift]) { - break; case NF::eR: f[ift] = std::hypot(sp.x(), sp.y()); - break; case NF::ePhi: f[ift] = std::atan2(sp.y(), sp.x()); - break; case NF::eZ: f[ift] = sp.z(); - break; case NF::eX: f[ift] = sp.x(); - break; case NF::eY: f[ift] = sp.y(); - break; case NF::eEta: f[ift] = Acts::VectorHelpers::eta(Acts::Vector3{sp.x(), sp.y(), sp.z()}); - break; case NF::eClusterX: f[ift] = cl1->sizeLoc0; - break; case NF::eClusterY: f[ift] = cl1->sizeLoc1; - break; case NF::eCellSum: f[ift] = cl1->sumActivations(); - break; case NF::eCellCount: f[ift] = cl1->channels.size(); - break; case NF::eCluster1R: f[ift] = std::hypot(cl1->globalPosition[Acts::ePos0], cl1->globalPosition[Acts::ePos1]); - break; case NF::eCluster2R: f[ift] = std::hypot(cl2->globalPosition[Acts::ePos0], cl2->globalPosition[Acts::ePos1]); - break; case NF::eCluster1Phi: f[ift] = std::atan2(cl1->globalPosition[Acts::ePos1], cl1->globalPosition[Acts::ePos0]); - break; case NF::eCluster2Phi: f[ift] = std::atan2(cl2->globalPosition[Acts::ePos1], cl2->globalPosition[Acts::ePos0]); - break; case NF::eCluster1Z: f[ift] = cl1->globalPosition[Acts::ePos2]; - break; case NF::eCluster2Z: f[ift] = cl2->globalPosition[Acts::ePos2]; - break; case NF::eCluster1Eta: f[ift] = Acts::VectorHelpers::eta(Acts::Vector3{cl1->globalPosition[Acts::ePos0], cl1->globalPosition[Acts::ePos1], cl1->globalPosition[Acts::ePos2]}); - break; case NF::eCluster2Eta: f[ift] = Acts::VectorHelpers::eta(Acts::Vector3{cl2->globalPosition[Acts::ePos0], cl2->globalPosition[Acts::ePos1], 
cl2->globalPosition[Acts::ePos2]}); - } - // clang-format on + auto features = createFeatures(spacepoints, clusters, m_cfg.nodeFeatures, + m_cfg.featureScales); - f[ift] /= m_cfg.featureScales[ift]; - } - } + auto t1 = Clock::now(); // Run the pipeline const auto trackCandidates = [&]() { std::lock_guard lock(m_mutex); Acts::ExaTrkXTiming timing; - auto res = m_pipeline.run(features, spacepointIDs, hook, &timing); + auto res = + m_pipeline.run(features, moduleIds, spacepointIDs, hook, &timing); m_timing.graphBuildingTime(timing.graphBuildingTime.count()); @@ -219,6 +199,8 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( return res; }(); + auto t2 = Clock::now(); + ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size() << " candidates"); @@ -228,20 +210,28 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( int nShortTracks = 0; - for (auto& x : trackCandidates) { - if (m_cfg.filterShortTracks && x.size() < 3) { + /// TODO the whole conversion back to meas idxs should be pulled out of the + /// track trackBuilder + for (auto& candidate : trackCandidates) { + ProtoTrack onetrack; + onetrack.reserve(candidate.size()); + + for (auto i : candidate) { + for (const auto& sl : spacepoints[i].sourceLinks()) { + onetrack.push_back(sl.template get().index()); + } + } + + if (onetrack.size() < m_cfg.minMeasurementsPerTrack) { nShortTracks++; continue; } - ProtoTrack onetrack; - onetrack.reserve(x.size()); - - std::copy(x.begin(), x.end(), std::back_inserter(onetrack)); protoTracks.push_back(std::move(onetrack)); } - ACTS_INFO("Removed " << nShortTracks << " with less then 3 hits"); + ACTS_INFO("Removed " << nShortTracks << " with less then " + << m_cfg.minMeasurementsPerTrack << " hits"); ACTS_INFO("Created " << protoTracks.size() << " proto tracks"); m_outputProtoTracks(ctx, std::move(protoTracks)); @@ -253,27 +243,33 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( 
m_outputGraph(ctx, {graph.first, graph.second}); } + auto t3 = Clock::now(); + m_timing.preprocessingTime(Duration(t1 - t0).count()); + m_timing.postprocessingTime(Duration(t3 - t2).count()); + return ActsExamples::ProcessCode::SUCCESS; } ActsExamples::ProcessCode TrackFindingAlgorithmExaTrkX::finalize() { namespace ba = boost::accumulators; + auto print = [](const auto& t) { + std::stringstream ss; + ss << ba::mean(t) << " +- " << std::sqrt(ba::variance(t)) << " "; + ss << "[" << ba::min(t) << ", " << ba::max(t) << "]"; + return ss.str(); + }; + ACTS_INFO("Exa.TrkX timing info"); - { - const auto& t = m_timing.graphBuildingTime; - ACTS_INFO("- graph building: " << ba::mean(t) << " +- " - << std::sqrt(ba::variance(t))); - } + ACTS_INFO("- preprocessing: " << print(m_timing.preprocessingTime)); + ACTS_INFO("- graph building: " << print(m_timing.graphBuildingTime)); + // clang-format off for (const auto& t : m_timing.classifierTimes) { - ACTS_INFO("- classifier: " << ba::mean(t) << " +- " - << std::sqrt(ba::variance(t))); - } - { - const auto& t = m_timing.trackBuildingTime; - ACTS_INFO("- track building: " << ba::mean(t) << " +- " - << std::sqrt(ba::variance(t))); + ACTS_INFO("- classifier: " << print(t)); } + // clang-format on + ACTS_INFO("- track building: " << print(m_timing.trackBuildingTime)); + ACTS_INFO("- postprocessing: " << print(m_timing.postprocessingTime)); return {}; } diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/createFeatures.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/createFeatures.cpp new file mode 100644 index 00000000000..426a27428f0 --- /dev/null +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/createFeatures.cpp @@ -0,0 +1,108 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. 
If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "createFeatures.hpp" + +#include "Acts/Utilities/AngleHelpers.hpp" +#include "Acts/Utilities/VectorHelpers.hpp" + +namespace ActsExamples { + +std::vector createFeatures( + const SimSpacePointContainer& spacepoints, + const std::optional& clusters, + const std::vector& nodeFeatures, + const std::vector& featureScales) { + using namespace ActsExamples; + + assert(nodeFeatures.size() == featureScales.size()); + std::vector features(spacepoints.size() * nodeFeatures.size()); + + for (auto isp = 0ul; isp < spacepoints.size(); ++isp) { + const auto& sp = spacepoints[isp]; + + // For now just take the first index since does require one single index + // per spacepoint + // TODO does it work for the module map construction to use only the first + // sp? + const auto& sl1 = sp.sourceLinks()[0].template get(); + + // This should be fine, because check in constructor + const Cluster* cl1 = clusters ? &clusters->at(sl1.index()) : nullptr; + const Cluster* cl2 = cl1; + + if (sp.sourceLinks().size() == 2) { + const auto& sl2 = sp.sourceLinks()[1].template get(); + cl2 = clusters ? 
&clusters->at(sl2.index()) : nullptr; + } + + // I would prefer to use a std::span or boost::span here once available + float* f = features.data() + isp * nodeFeatures.size(); + + using NF = TrackFindingAlgorithmExaTrkX::NodeFeature; + + using namespace Acts::VectorHelpers; + using namespace Acts::AngleHelpers; + + // clang-format off +#define MAKE_CLUSTER_FEATURES(n) \ + break; case NF::eCluster##n##X: f[ift] = cl##n->globalPosition[Acts::ePos0]; \ + break; case NF::eCluster##n##Y: f[ift] = cl##n->globalPosition[Acts::ePos1]; \ + break; case NF::eCluster##n##R: f[ift] = perp(cl##n->globalPosition); \ + break; case NF::eCluster##n##Phi: f[ift] = phi(cl##n->globalPosition); \ + break; case NF::eCluster##n##Z: f[ift] = cl##n->globalPosition[Acts::ePos2]; \ + break; case NF::eCluster##n##Eta: f[ift] = eta(cl##n->globalPosition); \ + break; case NF::eCellCount##n: f[ift] = cl##n->channels.size(); \ + break; case NF::eChargeSum##n: f[ift] = cl##n->sumActivations(); \ + break; case NF::eLocDir0##n: f[ift] = cl##n->localDirection[0]; \ + break; case NF::eLocDir1##n: f[ift] = cl##n->localDirection[1]; \ + break; case NF::eLocDir2##n: f[ift] = cl##n->localDirection[2]; \ + break; case NF::eLengthDir0##n: f[ift] = cl##n->lengthDirection[0]; \ + break; case NF::eLengthDir1##n: f[ift] = cl##n->lengthDirection[1]; \ + break; case NF::eLengthDir2##n: f[ift] = cl##n->lengthDirection[2]; \ + break; case NF::eLocEta##n: f[ift] = cl##n->localEta; \ + break; case NF::eLocPhi##n: f[ift] = cl##n->localPhi; \ + break; case NF::eGlobEta##n: f[ift] = cl##n->globalEta; \ + break; case NF::eGlobPhi##n: f[ift] = cl##n->globalPhi; \ + break; case NF::eEtaAngle##n: f[ift] = cl##n->etaAngle; \ + break; case NF::ePhiAngle##n: f[ift] = cl##n->phiAngle; + // clang-format on + + Acts::Vector3 spPos{sp.x(), sp.y(), sp.z()}; + + for (auto ift = 0ul; ift < nodeFeatures.size(); ++ift) { + // clang-format off + switch(nodeFeatures[ift]) { + // Spacepoint features + break; case NF::eR: f[ift] = 
perp(spPos); + break; case NF::ePhi: f[ift] = phi(spPos); + break; case NF::eZ: f[ift] = sp.z(); + break; case NF::eX: f[ift] = sp.x(); + break; case NF::eY: f[ift] = sp.y(); + break; case NF::eEta: f[ift] = eta(spPos); + // Single cluster features + break; case NF::eClusterLoc0: f[ift] = cl1->sizeLoc0; + break; case NF::eClusterLoc1: f[ift] = cl1->sizeLoc1; + break; case NF::eCellCount: f[ift] = cl1->channels.size(); + break; case NF::eChargeSum: f[ift] = cl1->sumActivations(); + // Features for split clusters + MAKE_CLUSTER_FEATURES(1) + MAKE_CLUSTER_FEATURES(2) + } + // clang-format on + + assert(std::isfinite(f[ift])); + f[ift] /= featureScales[ift]; + } +#undef MAKE_CLUSTER_FEATURES + } + + return features; +} + +} // namespace ActsExamples diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/createFeatures.hpp b/Examples/Algorithms/TrackFindingExaTrkX/src/createFeatures.hpp new file mode 100644 index 00000000000..dede5f3c11c --- /dev/null +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/createFeatures.hpp @@ -0,0 +1,21 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. 
+ +#pragma once + +#include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp" + +namespace ActsExamples { + +std::vector createFeatures( + const SimSpacePointContainer &spacepoints, + const std::optional &clusters, + const std::vector &nodeFeatures, + const std::vector &featureScales); + +} diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index 3778e8576f5..c7a30b51d2e 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -30,6 +30,7 @@ namespace py = pybind11; using namespace ActsExamples; using namespace Acts; +using namespace py::literals; namespace Acts::Python { @@ -67,7 +68,7 @@ void addExaTrkXTrackFinding(Context &ctx) { auto c = py::class_(alg, "Config").def(py::init<>()); ACTS_PYTHON_STRUCT_BEGIN(c, Config); ACTS_PYTHON_MEMBER(modelPath); - ACTS_PYTHON_MEMBER(numFeatures); + ACTS_PYTHON_MEMBER(selectedFeatures); ACTS_PYTHON_MEMBER(embeddingDim); ACTS_PYTHON_MEMBER(rVal); ACTS_PYTHON_MEMBER(knnVal); @@ -91,11 +92,12 @@ void addExaTrkXTrackFinding(Context &ctx) { auto c = py::class_(alg, "Config").def(py::init<>()); ACTS_PYTHON_STRUCT_BEGIN(c, Config); ACTS_PYTHON_MEMBER(modelPath); - ACTS_PYTHON_MEMBER(numFeatures); + ACTS_PYTHON_MEMBER(selectedFeatures); ACTS_PYTHON_MEMBER(cut); ACTS_PYTHON_MEMBER(nChunks); ACTS_PYTHON_MEMBER(undirected); ACTS_PYTHON_MEMBER(deviceID); + ACTS_PYTHON_MEMBER(useEdgeFeatures); ACTS_PYTHON_STRUCT_END(); } { @@ -174,36 +176,61 @@ void addExaTrkXTrackFinding(Context &ctx) { inputMeasurementSimHitsMap, inputMeasurementParticlesMap, outputGraph, targetMinPT, targetMinSize, uniqueModules); - py::enum_(mex, "NodeFeature") - .value("R", TrackFindingAlgorithmExaTrkX::NodeFeature::eR) - .value("Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::ePhi) - .value("Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eZ) - .value("X", TrackFindingAlgorithmExaTrkX::NodeFeature::eX) - .value("Y", 
TrackFindingAlgorithmExaTrkX::NodeFeature::eY) - .value("Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eEta) - .value("ClusterX", TrackFindingAlgorithmExaTrkX::NodeFeature::eClusterX) - .value("ClusterY", TrackFindingAlgorithmExaTrkX::NodeFeature::eClusterY) - .value("CellCount", TrackFindingAlgorithmExaTrkX::NodeFeature::eCellCount) - .value("CellSum", TrackFindingAlgorithmExaTrkX::NodeFeature::eCellSum) - .value("Cluster1R", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster1R) - .value("Cluster2R", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster2R) - .value("Cluster1Phi", - TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster1Phi) - .value("Cluster2Phi", - TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster2Phi) - .value("Cluster1Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster1Z) - .value("Cluster2Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster2Z) - .value("Cluster1Eta", - TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster1Eta) - .value("Cluster2Eta", - TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster2Eta); - - ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::TrackFindingAlgorithmExaTrkX, mex, - "TrackFindingAlgorithmExaTrkX", - inputSpacePoints, inputClusters, - inputTruthGraph, outputProtoTracks, outputGraph, - graphConstructor, edgeClassifiers, trackBuilder, - nodeFeatures, featureScales, filterShortTracks); + { + auto nodeFeatureEnum = + py::enum_(mex, "NodeFeature") + .value("R", TrackFindingAlgorithmExaTrkX::NodeFeature::eR) + .value("Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::ePhi) + .value("Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eZ) + .value("X", TrackFindingAlgorithmExaTrkX::NodeFeature::eX) + .value("Y", TrackFindingAlgorithmExaTrkX::NodeFeature::eY) + .value("Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eEta) + .value("ClusterX", + TrackFindingAlgorithmExaTrkX::NodeFeature::eClusterLoc0) + .value("ClusterY", + TrackFindingAlgorithmExaTrkX::NodeFeature::eClusterLoc1) + .value("CellCount", + 
TrackFindingAlgorithmExaTrkX::NodeFeature::eCellCount) + .value("ChargeSum", + TrackFindingAlgorithmExaTrkX::NodeFeature::eChargeSum); + + // clang-format off +#define ADD_FEATURE_ENUMS(n) \ + nodeFeatureEnum \ + .value("Cluster" #n "X", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##X) \ + .value("Cluster" #n "Y", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Y) \ + .value("Cluster" #n "Z", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Z) \ + .value("Cluster" #n "R", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##R) \ + .value("Cluster" #n "Phi", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Phi) \ + .value("Cluster" #n "Eta", TrackFindingAlgorithmExaTrkX::NodeFeature::eCluster##n##Eta) \ + .value("CellCount" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eCellCount##n) \ + .value("ChargeSum" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eChargeSum##n) \ + .value("LocEta" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocEta##n) \ + .value("LocPhi" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocPhi##n) \ + .value("LocDir0" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir0##n) \ + .value("LocDir1" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir1##n) \ + .value("LocDir2" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLocDir2##n) \ + .value("LengthDir0" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir0##n) \ + .value("LengthDir1" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir1##n) \ + .value("LengthDir2" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eLengthDir2##n) \ + .value("GlobEta" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eGlobEta##n) \ + .value("GlobPhi" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eGlobPhi##n) \ + .value("EtaAngle" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::eEtaAngle##n) \ + .value("PhiAngle" #n, TrackFindingAlgorithmExaTrkX::NodeFeature::ePhiAngle##n) + // clang-format on + + ADD_FEATURE_ENUMS(1); + ADD_FEATURE_ENUMS(2); + +#undef 
ADD_FEATURE_ENUMS + } + + ACTS_PYTHON_DECLARE_ALGORITHM( + ActsExamples::TrackFindingAlgorithmExaTrkX, mex, + "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputClusters, + inputTruthGraph, outputProtoTracks, outputGraph, graphConstructor, + edgeClassifiers, trackBuilder, nodeFeatures, featureScales, + minMeasurementsPerTrack, geometryIdMap); { auto cls = @@ -239,7 +266,8 @@ void addExaTrkXTrackFinding(Context &ctx) { py::arg("graphConstructor"), py::arg("edgeClassifiers"), py::arg("trackBuilder"), py::arg("level")) .def("run", &ExaTrkXPipeline::run, py::arg("features"), - py::arg("spacepoints"), py::arg("hook") = Acts::ExaTrkXHook{}, + py::arg("moduleIds"), py::arg("spacepoints"), + py::arg("hook") = Acts::ExaTrkXHook{}, py::arg("timing") = nullptr); } diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp index 60d822e4b1f..1e5d1335a72 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp @@ -46,6 +46,7 @@ class ExaTrkXPipeline { std::unique_ptr logger); std::vector> run(std::vector &features, + const std::vector &moduleIds, std::vector &spacepointIDs, const ExaTrkXHook &hook = {}, ExaTrkXTiming *timing = nullptr) const; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp index d672acfe78e..5682b7e84a9 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp @@ -33,8 +33,8 @@ class OnnxEdgeClassifier final : public Acts::EdgeClassificationBase { OnnxEdgeClassifier(const Config &cfg, std::unique_ptr logger); ~OnnxEdgeClassifier(); - std::tuple operator()( - std::any nodes, std::any edges, + std::tuple operator()( + std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {}, 
torch::Device device = torch::Device(torch::kCPU)) override; Config config() const { return m_cfg; } @@ -49,9 +49,8 @@ class OnnxEdgeClassifier final : public Acts::EdgeClassificationBase { std::unique_ptr m_env; std::unique_ptr m_model; - std::string m_inputNameNodes; - std::string m_inputNameEdges; - std::string m_outputNameScores; + std::vector m_inputNames; + std::string m_outputName; }; } // namespace Acts diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp index 31c8adeea8e..d78139d0732 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp @@ -36,8 +36,9 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase { OnnxMetricLearning(const Config& cfg, std::unique_ptr logger); ~OnnxMetricLearning(); - std::tuple operator()( + std::tuple operator()( std::vector& inputValues, std::size_t numNodes, + const std::vector& moduleIds, torch::Device device = torch::Device(torch::kCPU)) override; Config config() const { return m_cfg; } diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp index 2cb676b1bbe..1e35fb08a82 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp @@ -9,12 +9,17 @@ #pragma once #include +#include +#include #include #include namespace Acts { +/// Error that is thrown if no edges are found +struct NoEdgesError : std::exception {}; + // TODO maybe replace std::any with some kind of variant, // unique_ptr>? // TODO maybe replace input for GraphConstructionBase with some kind of @@ -27,11 +32,14 @@ class GraphConstructionBase { /// @param inputValues Flattened input data /// @param numNodes Number of nodes. 
inputValues.size() / numNodes /// then gives the number of features + /// @param moduleIds Module IDs of the features (used for module-map-like + /// graph construction) /// @param device Which GPU device to pick. Not relevant for CPU-only builds /// - /// @return (node_tensor, edge_tensore) - virtual std::tuple operator()( + /// @return (node_features, edge_features, edge_index) + virtual std::tuple operator()( std::vector &inputValues, std::size_t numNodes, + const std::vector &moduleIds, torch::Device device = torch::Device(torch::kCPU)) = 0; virtual torch::Device device() const = 0; @@ -43,13 +51,14 @@ class EdgeClassificationBase { public: /// Perform edge classification /// - /// @param nodes Node tensor with shape (n_nodes, n_node_features) - /// @param edges Edge-index tensor with shape (2, n_edges) + /// @param nodeFeatures Node tensor with shape (n_nodes, n_node_features) + /// @param edgeIndex Edge-index tensor with shape (2, n_edges) + /// @param edgeFeatures Edge-feature tensor with shape (n_edges, n_edge_features) /// @param device Which GPU device to pick. 
Not relevant for CPU-only builds /// - /// @return (node_tensor, edge_tensor, score_tensor) - virtual std::tuple operator()( - std::any nodes, std::any edges, + /// @return (node_features, edge_features, edge_index, edge_scores) + virtual std::tuple operator()( + std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {}, torch::Device device = torch::Device(torch::kCPU)) = 0; virtual torch::Device device() const = 0; @@ -61,15 +70,15 @@ class TrackBuildingBase { public: /// Perform track building /// - /// @param nodes Node tensor with shape (n_nodes, n_node_features) - /// @param edges Edge-index tensor with shape (2, n_edges) - /// @param edgeWeights Edge-weights of the previous edge classification phase + /// @param nodeFeatures Node tensor with shape (n_nodes, n_node_features) + /// @param edgeIndex Edge-index tensor with shape (2, n_edges) + /// @param edgeScores Scores of the previous edge classification phase /// @param spacepointIDs IDs of the nodes (must have size=n_nodes) /// @param device Which GPU device to pick. 
Not relevant for CPU-only builds /// /// @return tracks (as vectors of node-IDs) virtual std::vector> operator()( - std::any nodes, std::any edges, std::any edgeWeights, + std::any nodeFeatures, std::any edgeIndex, std::any edgeScores, std::vector &spacepointIDs, torch::Device device = torch::Device(torch::kCPU)) = 0; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp index 3dcae08e0c6..4cf92a7115d 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp @@ -29,18 +29,19 @@ class TorchEdgeClassifier final : public Acts::EdgeClassificationBase { public: struct Config { std::string modelPath; - int numFeatures = 3; + std::vector selectedFeatures = {}; float cut = 0.21; int nChunks = 1; // NOTE for GNN use 1 bool undirected = false; int deviceID = 0; + bool useEdgeFeatures = false; }; TorchEdgeClassifier(const Config &cfg, std::unique_ptr logger); ~TorchEdgeClassifier(); - std::tuple operator()( - std::any nodes, std::any edges, + std::tuple operator()( + std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {}, torch::Device device = torch::Device(torch::kCPU)) override; Config config() const { return m_cfg; } diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp index fe31da1ee3e..9d87e5c59d5 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp @@ -27,19 +27,23 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase { public: struct Config { std::string modelPath; - int numFeatures = 3; + std::vector selectedFeatures = {}; int embeddingDim = 8; float rVal = 1.6; int knnVal = 500; bool shuffleDirections = false; int deviceID = 0; // default 
is the first GPU if available + + // For edge features + float phiScale = 3.141592654; }; TorchMetricLearning(const Config &cfg, std::unique_ptr logger); ~TorchMetricLearning(); - std::tuple operator()( + std::tuple operator()( std::vector &inputValues, std::size_t numNodes, + const std::vector &moduleIds, torch::Device device = torch::Device(torch::kCPU)) override; Config config() const { return m_cfg; } diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/detail/Utils.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/detail/Utils.hpp new file mode 100644 index 00000000000..64d8830f3c9 --- /dev/null +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/detail/Utils.hpp @@ -0,0 +1,51 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. 
+ +#pragma once + +#include + +#include + +namespace Acts::detail { + +struct TensorDetails { + const torch::Tensor &tensor; + TensorDetails(const torch::Tensor &t) : tensor(t) {} +}; + +inline std::ostream &operator<<(std::ostream &os, const TensorDetails &t) { + os << t.tensor.dtype() << ", " << t.tensor.sizes(); + if (at::isnan(t.tensor).any().item()) { + os << ", contains NaNs"; + } else { + os << ", no NaNs"; + } + return os; +} + +template +struct RangePrinter { + It begin; + It end; + + RangePrinter(It a, It b) : begin(a), end(b) {} +}; + +template +RangePrinter(It b, It e) -> RangePrinter; + +template +inline std::ostream &operator<<(std::ostream &os, const RangePrinter &r) { + for (auto it = r.begin; it != r.end; ++it) { + os << *it << " "; + } + return os; +} + +} // namespace Acts::detail diff --git a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp index d53de1d3071..6a05ae96f8e 100644 --- a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp +++ b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp @@ -36,52 +36,64 @@ ExaTrkXPipeline::ExaTrkXPipeline( } std::vector> ExaTrkXPipeline::run( - std::vector &features, std::vector &spacepointIDs, - const ExaTrkXHook &hook, ExaTrkXTiming *timing) const { - auto t0 = std::chrono::high_resolution_clock::now(); - auto [nodes, edges] = (*m_graphConstructor)(features, spacepointIDs.size(), - m_graphConstructor->device()); - auto t1 = std::chrono::high_resolution_clock::now(); - - if (timing != nullptr) { - timing->graphBuildingTime = t1 - t0; - } + std::vector &features, const std::vector &moduleIds, + std::vector &spacepointIDs, const ExaTrkXHook &hook, + ExaTrkXTiming *timing) const { + try { + auto t0 = std::chrono::high_resolution_clock::now(); + auto [nodeFeatures, edgeIndex, edgeFeatures] = + (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, + m_graphConstructor->device()); + auto t1 = std::chrono::high_resolution_clock::now(); - hook(nodes, edges, {}); + if (timing != nullptr) { + 
timing->graphBuildingTime = t1 - t0; + } - std::any edge_weights; - if (timing != nullptr) { + hook(nodeFeatures, edgeIndex, {}); + + std::any edgeScores; timing->classifierTimes.clear(); - } - for (auto edgeClassifier : m_edgeClassifiers) { + for (auto edgeClassifier : m_edgeClassifiers) { + t0 = std::chrono::high_resolution_clock::now(); + auto [newNodeFeatures, newEdgeIndex, newEdgeFeatures, newEdgeScores] = + (*edgeClassifier)(std::move(nodeFeatures), std::move(edgeIndex), + std::move(edgeFeatures), edgeClassifier->device()); + t1 = std::chrono::high_resolution_clock::now(); + + if (timing != nullptr) { + timing->classifierTimes.push_back(t1 - t0); + } + + nodeFeatures = std::move(newNodeFeatures); + edgeFeatures = std::move(newEdgeFeatures); + edgeIndex = std::move(newEdgeIndex); + edgeScores = std::move(newEdgeScores); + + hook(nodeFeatures, edgeIndex, edgeScores); + } + t0 = std::chrono::high_resolution_clock::now(); - auto [newNodes, newEdges, newWeights] = (*edgeClassifier)( - std::move(nodes), std::move(edges), edgeClassifier->device()); + auto res = (*m_trackBuilder)(std::move(nodeFeatures), std::move(edgeIndex), + std::move(edgeScores), spacepointIDs, + m_trackBuilder->device()); t1 = std::chrono::high_resolution_clock::now(); if (timing != nullptr) { - timing->classifierTimes.push_back(t1 - t0); + timing->trackBuildingTime = t1 - t0; } - nodes = std::move(newNodes); - edges = std::move(newEdges); - edge_weights = std::move(newWeights); - - hook(nodes, edges, edge_weights); - } - - t0 = std::chrono::high_resolution_clock::now(); - auto res = (*m_trackBuilder)(std::move(nodes), std::move(edges), - std::move(edge_weights), spacepointIDs, - m_trackBuilder->device()); - t1 = std::chrono::high_resolution_clock::now(); - - if (timing != nullptr) { - timing->trackBuildingTime = t1 - t0; + return res; + } catch (Acts::NoEdgesError &) { + ACTS_WARNING("No egdges left in GNN pipeline, return 0 track candidates"); + if (timing != nullptr) { + while 
(timing->classifierTimes.size() < m_edgeClassifiers.size()) { + timing->classifierTimes.push_back({}); + } + } + return {}; } - - return res; } } // namespace Acts diff --git a/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp b/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp index ac6d28ce358..59452daec09 100644 --- a/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp +++ b/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp @@ -8,11 +8,11 @@ #include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp" +#include "Acts/Plugins/ExaTrkX/detail/Utils.hpp" + #include #include -#include "runSessionWithIoBinding.hpp" - using namespace torch::indexing; namespace Acts { @@ -27,80 +27,143 @@ OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg, Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(1); - session_options.SetGraphOptimizationLevel( - GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + // session_options.SetGraphOptimizationLevel( + // GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + + OrtCUDAProviderOptions cuda_options; + cuda_options.device_id = 0; + // session_options.AppendExecutionProvider_CUDA(cuda_options); m_model = std::make_unique(*m_env, m_cfg.modelPath.c_str(), session_options); Ort::AllocatorWithDefaultOptions allocator; - m_inputNameNodes = - std::string(m_model->GetInputNameAllocated(0, allocator).get()); - m_inputNameEdges = - std::string(m_model->GetInputNameAllocated(1, allocator).get()); - m_outputNameScores = + for (std::size_t i = 0; i < m_model->GetInputCount(); ++i) { + m_inputNames.emplace_back( + m_model->GetInputNameAllocated(i, allocator).get()); + } + m_outputName = std::string(m_model->GetOutputNameAllocated(0, allocator).get()); } OnnxEdgeClassifier::~OnnxEdgeClassifier() {} -std::tuple OnnxEdgeClassifier::operator()( - std::any inputNodes, std::any inputEdges, torch::Device) { - Ort::AllocatorWithDefaultOptions allocator; - auto memoryInfo = Ort::MemoryInfo::CreateCpu( - OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); - - 
auto eInputTensor = std::any_cast>(inputNodes); - auto edgeList = std::any_cast>(inputEdges); - const int numEdges = edgeList.size() / 2; - - std::vector fInputNames{m_inputNameNodes.c_str(), - m_inputNameEdges.c_str()}; - std::vector fInputTensor; - fInputTensor.push_back(std::move(*eInputTensor)); - std::vector fEdgeShape{2, numEdges}; - fInputTensor.push_back(Ort::Value::CreateTensor( - memoryInfo, edgeList.data(), edgeList.size(), fEdgeShape.data(), - fEdgeShape.size())); - - // filtering outputs - std::vector fOutputNames{m_outputNameScores.c_str()}; - std::vector fOutputData(numEdges); - - auto outputDims = m_model->GetOutputTypeInfo(0) - .GetTensorTypeAndShapeInfo() - .GetDimensionsCount(); - using Shape = std::vector; - Shape fOutputShape = outputDims == 2 ? Shape{numEdges, 1} : Shape{numEdges}; - std::vector fOutputTensor; - fOutputTensor.push_back(Ort::Value::CreateTensor( - memoryInfo, fOutputData.data(), fOutputData.size(), fOutputShape.data(), - fOutputShape.size())); - runSessionWithIoBinding(*m_model, fInputNames, fInputTensor, fOutputNames, - fOutputTensor); +template +auto torchToOnnx(Ort::MemoryInfo &memInfo, at::Tensor &tensor) { + std::vector shape{tensor.size(0), tensor.size(1)}; + return Ort::Value::CreateTensor(memInfo, tensor.data_ptr(), + tensor.numel(), shape.data(), + shape.size()); +} + +std::ostream &operator<<(std::ostream &os, Ort::Value &v) { + if (!v.IsTensor()) { + os << "no tensor"; + return os; + } + + auto shape = v.GetTensorTypeAndShapeInfo().GetShape(); + + auto printVal = [&]() { + for (int i = 0; i < shape.at(0); ++i) { + for (int j = 0; j < shape.at(1); ++j) { + os << v.At({i, j}) << " "; + } + os << "\n"; + } + }; + + auto type = v.GetTensorTypeAndShapeInfo().GetElementType(); + if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + os << "[float tensor]\n"; + printVal.operator()(); + } else if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + os << "[int64 tensor]\n"; + printVal.operator()(); + } else { + os << "not 
implemented datatype"; + } + + return os; +} + +std::tuple +OnnxEdgeClassifier::operator()(std::any inputNodes, std::any inputEdges, + std::any inEdgeFeatures, torch::Device) { + auto torchDevice = torch::kCPU; + Ort::MemoryInfo memoryInfo("Cpu", OrtArenaAllocator, /*device_id*/ 0, + OrtMemTypeDefault); + + Ort::Allocator allocator(*m_model, memoryInfo); + + auto nodeTensor = + std::any_cast(inputNodes).to(torchDevice).clone(); + auto edgeList = std::any_cast(inputEdges).to(torchDevice); + const int numEdges = edgeList.size(1); + + std::vector inputNames{m_inputNames.at(0).c_str(), + m_inputNames.at(1).c_str()}; + + // TODO move this contiguous to graph construction + auto edgeListClone = edgeList.clone().contiguous(); + ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeListClone}); + auto nodeTensorClone = nodeTensor.clone(); + ACTS_DEBUG("nodes: " << detail::TensorDetails{nodeTensorClone}); + std::vector inputTensors; + inputTensors.push_back(torchToOnnx(memoryInfo, nodeTensorClone)); + inputTensors.push_back(torchToOnnx(memoryInfo, edgeListClone)); + + std::optional edgeAttrTensor; + if (inEdgeFeatures.has_value()) { + inputNames.push_back(m_inputNames.at(2).c_str()); + edgeAttrTensor = + std::any_cast(inEdgeFeatures).to(torchDevice).clone(); + inputTensors.push_back(torchToOnnx(memoryInfo, *edgeAttrTensor)); + } + + std::vector outputNames{m_outputName.c_str()}; + + auto outputTensor = + m_model->Run({}, inputNames.data(), inputTensors.data(), + inputTensors.size(), outputNames.data(), outputNames.size()); + + float *rawOutData = nullptr; + if (outputTensor.at(0).GetTensorTypeAndShapeInfo().GetElementType() == + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + rawOutData = outputTensor.at(0).GetTensorMutableData(); + } else { + throw std::runtime_error("Invalid output datatype"); + } ACTS_DEBUG("Get scores for " << numEdges << " edges."); - torch::Tensor edgeListCTen = torch::tensor(edgeList, {torch::kInt64}); - edgeListCTen = edgeListCTen.reshape({2, numEdges}); + 
auto scores = + torch::from_blob( + rawOutData, {numEdges}, + torch::TensorOptions().device(torchDevice).dtype(torch::kFloat32)) + .clone(); + + ACTS_VERBOSE("Slice of classified output before sigmoid:\n" + << scores.slice(/*dim=*/0, /*start=*/0, /*end=*/9)); + + scores.sigmoid_(); - torch::Tensor fOutputCTen = torch::tensor(fOutputData, {torch::kFloat32}); - fOutputCTen = fOutputCTen.sigmoid(); + ACTS_DEBUG("scores: " << detail::TensorDetails{scores}); + ACTS_VERBOSE("Slice of classified output:\n" + << scores.slice(/*dim=*/0, /*start=*/0, /*end=*/9)); - torch::Tensor filterMask = fOutputCTen > m_cfg.cut; - torch::Tensor edgesAfterFCTen = edgeListCTen.index({Slice(), filterMask}); + torch::Tensor filterMask = scores > m_cfg.cut; + torch::Tensor edgesAfterCut = edgeList.index({Slice(), filterMask}); - std::vector edgesAfterFiltering; - std::copy(edgesAfterFCTen.data_ptr(), - edgesAfterFCTen.data_ptr() + edgesAfterFCTen.numel(), - std::back_inserter(edgesAfterFiltering)); + ACTS_DEBUG("Finished edge classification, after cut: " + << edgesAfterCut.size(1) << " edges."); - std::int64_t numEdgesAfterF = edgesAfterFiltering.size() / 2; - ACTS_DEBUG("Finished edge classification, after cut: " << numEdgesAfterF - << " edges."); + if (edgesAfterCut.size(1) == 0) { + throw Acts::NoEdgesError{}; + } - return {std::make_shared(std::move(fInputTensor[0])), - edgesAfterFiltering, fOutputCTen}; + return {std::move(nodeTensor), edgesAfterCut.clone(), + std::move(inEdgeFeatures), std::move(scores)}; } } // namespace Acts diff --git a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp index 6187765b4ce..b81335ad0a9 100644 --- a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp +++ b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp @@ -8,6 +8,10 @@ #include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp" +#include "Acts/Plugins/ExaTrkX/detail/Utils.hpp" + +#include + #ifndef ACTS_EXATRKX_CPUONLY #include #endif @@ -61,9 +65,12 @@ 
TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg, TorchEdgeClassifier::~TorchEdgeClassifier() {} -std::tuple TorchEdgeClassifier::operator()( - std::any inputNodes, std::any inputEdges, torch::Device device) { - ACTS_DEBUG("Start edge classification"); +std::tuple +TorchEdgeClassifier::operator()(std::any inNodeFeatures, std::any inEdgeIndex, + std::any inEdgeFeatures, torch::Device device) { + decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4, t5; + t0 = std::chrono::high_resolution_clock::now(); + ACTS_DEBUG("Start edge classification, use " << device); c10::InferenceMode guard(true); // add a protection to avoid calling for kCPU @@ -74,51 +81,73 @@ std::tuple TorchEdgeClassifier::operator()( } #endif - auto nodes = std::any_cast(inputNodes).to(device); - auto edgeList = std::any_cast(inputEdges).to(device); + auto nodeFeatures = std::any_cast(inNodeFeatures).to(device); + auto edgeIndex = std::any_cast(inEdgeIndex).to(device); - auto model = m_model->clone(); - model.to(device); + if (edgeIndex.numel() == 0) { + throw NoEdgesError{}; + } - if (m_cfg.numFeatures > nodes.size(1)) { - throw std::runtime_error("requested more features then available"); + ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex}); + + std::optional edgeFeatures; + if (inEdgeFeatures.has_value()) { + edgeFeatures = std::any_cast(inEdgeFeatures).to(device); + ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{*edgeFeatures}); } + t1 = std::chrono::high_resolution_clock::now(); torch::Tensor output; // Scope this to keep inference objects separate { - auto edgeListTmp = m_cfg.undirected - ? torch::cat({edgeList, edgeList.flip(0)}, 1) - : edgeList; + auto edgeIndexTmp = m_cfg.undirected + ? torch::cat({edgeIndex, edgeIndex.flip(0)}, 1) + : edgeIndex; std::vector inputTensors(2); - inputTensors[0] = - m_cfg.numFeatures < nodes.size(1) - ? 
nodes.index({Slice{}, Slice{None, m_cfg.numFeatures}}) - : nodes; + auto selectedFeaturesTensor = + at::tensor(at::ArrayRef(m_cfg.selectedFeatures)); + at::Tensor selectedNodeFeatures = + !m_cfg.selectedFeatures.empty() + ? nodeFeatures.index({Slice{}, selectedFeaturesTensor}).clone() + : nodeFeatures; + + ACTS_DEBUG("selected nodeFeatures: " + << detail::TensorDetails{selectedNodeFeatures}); + inputTensors[0] = selectedNodeFeatures; + + if (edgeFeatures && m_cfg.useEdgeFeatures) { + inputTensors.push_back(*edgeFeatures); + } if (m_cfg.nChunks > 1) { std::vector results; results.reserve(m_cfg.nChunks); - auto chunks = at::chunk(edgeListTmp, m_cfg.nChunks, 1); + auto chunks = at::chunk(edgeIndexTmp, m_cfg.nChunks, 1); for (auto& chunk : chunks) { ACTS_VERBOSE("Process chunk with shape" << chunk.sizes()); inputTensors[1] = chunk; - results.push_back(model.forward(inputTensors).toTensor()); + results.push_back(m_model->forward(inputTensors).toTensor()); results.back().squeeze_(); } output = torch::cat(results); } else { - inputTensors[1] = edgeListTmp; - output = model.forward(inputTensors).toTensor(); + inputTensors[1] = edgeIndexTmp; + + t2 = std::chrono::high_resolution_clock::now(); + output = m_model->forward(inputTensors).toTensor().to(torch::kFloat32); + t3 = std::chrono::high_resolution_clock::now(); output.squeeze_(); } } + ACTS_VERBOSE("Slice of classified output before sigmoid:\n" + << output.slice(/*dim=*/0, /*start=*/0, /*end=*/9)); + output.sigmoid_(); if (m_cfg.undirected) { @@ -132,14 +161,23 @@ std::tuple TorchEdgeClassifier::operator()( printCudaMemInfo(logger()); torch::Tensor mask = output > m_cfg.cut; - torch::Tensor edgesAfterCut = edgeList.index({Slice(), mask}); + torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask}); edgesAfterCut = edgesAfterCut.to(torch::kInt64); ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1)); printCudaMemInfo(logger()); - - return {std::move(nodes), std::move(edgesAfterCut), - 
output.masked_select(mask)}; + t4 = std::chrono::high_resolution_clock::now(); + + auto milliseconds = [](const auto& a, const auto& b) { + return std::chrono::duration(b - a).count(); + }; + ACTS_DEBUG("Time anycast, device guard: " << milliseconds(t0, t1)); + ACTS_DEBUG("Time jit::IValue creation: " << milliseconds(t1, t2)); + ACTS_DEBUG("Time model forward: " << milliseconds(t2, t3)); + ACTS_DEBUG("Time sigmoid and cut: " << milliseconds(t3, t4)); + + return {std::move(nodeFeatures), std::move(edgesAfterCut), + std::move(inEdgeFeatures), output.masked_select(mask)}; } } // namespace Acts diff --git a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp index b65c551a63b..08088a5d904 100644 --- a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp +++ b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp @@ -15,6 +15,8 @@ #include #endif +#include + #include #include @@ -65,9 +67,9 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg, TorchMetricLearning::~TorchMetricLearning() {} -std::tuple TorchMetricLearning::operator()( +std::tuple TorchMetricLearning::operator()( std::vector &inputValues, std::size_t numNodes, - torch::Device device) { + const std::vector & /*moduleIds*/, torch::Device device) { ACTS_DEBUG("Start graph construction"); c10::InferenceMode guard(true); @@ -105,18 +107,16 @@ std::tuple TorchMetricLearning::operator()( // Embedding // ********** - if (m_cfg.numFeatures > numAllFeatures) { - throw std::runtime_error("requested more features then available"); - } - // Clone models (solve memory leak? members can be const...) auto model = m_model->clone(); model.to(device); std::vector inputTensors; + auto selectedFeaturesTensor = + at::tensor(at::ArrayRef(m_cfg.selectedFeatures)); inputTensors.push_back( - m_cfg.numFeatures < numAllFeatures - ? inputTensor.index({Slice{}, Slice{None, m_cfg.numFeatures}}) + !m_cfg.selectedFeatures.empty() + ? 
inputTensor.index({Slice{}, selectedFeaturesTensor}) : std::move(inputTensor)); ACTS_DEBUG("embedding input tensor shape " @@ -141,6 +141,9 @@ std::tuple TorchMetricLearning::operator()( ACTS_VERBOSE("Slice of edgelist:\n" << edgeList.slice(1, 0, 5)); printCudaMemInfo(logger()); - return {std::move(inputTensors[0]).toTensor(), std::move(edgeList)}; + // TODO add real edge features for this workflow later + std::any edgeFeatures; + return {std::move(inputTensors[0]).toTensor(), std::move(edgeList), + std::move(edgeFeatures)}; } } // namespace Acts diff --git a/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp b/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp index ed766ffa16b..7873f9b56fd 100644 --- a/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp +++ b/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp @@ -9,11 +9,14 @@ #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp" #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp" +#include "Acts/Plugins/ExaTrkX/detail/Utils.hpp" #include #include +using namespace torch::indexing; + namespace { auto cantorize(std::vector edgeIndex, @@ -22,6 +25,7 @@ auto cantorize(std::vector edgeIndex, // operations to compute efficiency and purity std::vector> cantorEdgeIndex; cantorEdgeIndex.reserve(edgeIndex.size() / 2); + for (auto it = edgeIndex.begin(); it != edgeIndex.end(); it += 2) { cantorEdgeIndex.emplace_back(*it, *std::next(it)); } @@ -52,9 +56,27 @@ Acts::TorchTruthGraphMetricsHook::TorchTruthGraphMetricsHook( void Acts::TorchTruthGraphMetricsHook::operator()(const std::any&, const std::any& edges, const std::any&) const { + auto edgeIndexTensor = + std::any_cast(edges).to(torch::kCPU).contiguous(); + ACTS_VERBOSE("edge index tensor: " << detail::TensorDetails{edgeIndexTensor}); + + const auto numEdges = edgeIndexTensor.size(1); + if (numEdges == 0) { + ACTS_WARNING("no edges, cannot compute metrics"); + return; + } + ACTS_VERBOSE("Edge index slice:\n" + << edgeIndexTensor.index( + {Slice(0, 
2), Slice(0, std::min(numEdges, 10l))})); + // We need to transpose the edges here for the right memory layout - const auto edgeIndex = Acts::detail::tensor2DToVector( - std::any_cast(edges).t()); + const auto edgeIndex = + Acts::detail::tensor2DToVector(edgeIndexTensor.t().clone()); + + ACTS_VERBOSE("Edge vector:\n" + << (detail::RangePrinter{ + edgeIndex.begin(), + edgeIndex.begin() + std::min(numEdges, 10l)})); auto predGraphCantor = cantorize(edgeIndex, logger()); diff --git a/Plugins/ExaTrkX/src/printCudaMemInfo.hpp b/Plugins/ExaTrkX/src/printCudaMemInfo.hpp index 64d5497dfc7..9e151fd0d98 100644 --- a/Plugins/ExaTrkX/src/printCudaMemInfo.hpp +++ b/Plugins/ExaTrkX/src/printCudaMemInfo.hpp @@ -22,7 +22,7 @@ namespace { inline void printCudaMemInfo(const Acts::Logger& logger) { #ifndef ACTS_EXATRKX_CPUONLY - if (torch::cuda::is_available()) { + if (torch::cuda::is_available() && logger.level() == Acts::Logging::VERBOSE) { constexpr float kb = 1024; constexpr float mb = kb * kb;