Skip to content

Commit

Permalink
Clean up 1
Browse files Browse the repository at this point in the history
  • Loading branch information
fredevb committed Jun 20, 2024
1 parent 93a67bf commit e647915
Show file tree
Hide file tree
Showing 12 changed files with 474 additions and 204 deletions.
4 changes: 4 additions & 0 deletions Examples/Algorithms/Traccc/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
add_library(
ActsExamplesTracccCommon SHARED
src/TracccChainAlgorithmBase.cpp
src/Conversion/CellMapConversion.cpp
src/Conversion/DigitizationConversion.cpp
src/Conversion/MeasurementConversion.cpp
src/Debug/Debug.cpp
)

target_include_directories(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@
// Traccc include(s)
#include "traccc/edm/cell.hpp"

// VecMem include(s).
#include <vecmem/memory/memory_resource.hpp>

// Boost include(s)
#include <boost/range/combine.hpp>

// System include(s).
#include <cstdint>
#include <cstdlib>
Expand All @@ -31,65 +25,10 @@

namespace ActsExamples::Traccc::Common::Conversion {

/// @brief Gets the time of the cell.
/// @note Currently, it always returns 0.
inline float getTime(const Cluster::Cell& /*cell*/){
return 0.f;
}

/// @brief Gets the activation of the cell.
inline float getActivation(const Cluster::Cell& cell){
return static_cast<float>(cell.activation);
}

/// @brief Gets the row of the cell.
inline unsigned int getRow(const Cluster::Cell& cell){
if (cell.bin[0] > UINT_MAX) {
throw std::runtime_error("Overflow will occur when casting to unsigned int.");
}
return static_cast<unsigned int>(cell.bin[0]);
}

/// @brief Gets the column of the cell.
inline unsigned int getColumn(const Cluster::Cell& cell){
if (cell.bin[0] > UINT_MAX) {
throw std::runtime_error("Overflow will occur when casting to unsigned int.");
}
return static_cast<unsigned int>(cell.bin[1]);
}

/// @brief Creates a traccc cell from a generic cell type.
/// @param cell the generic cell.
/// @param moduleLink the module link value to set for the traccc cell that is created.
/// @returns a traccc cell.
/// @note the functions getRow(cell_t), getColumn(cell_t), getActivation(cell_t), getTime(cell_t) are expected.
template <typename cell_t>
auto tracccCell(const cell_t& cell, const traccc::cell::link_type moduleLink = 0){
return traccc::cell{
getRow(cell),
getColumn(cell),
getActivation(cell),
getTime(cell),
moduleLink
};
}

/// @brief Converts a "geometry ID -> generic cell collection type" map to a "geometry ID -> traccc cell collection" map.
/// @note The function sets the module link of the cells in the output to 0.
/// @return Map from geometry ID to its cell data (as a vector of traccc cell data)
template <typename cell_collection_t>
inline std::map<std::uint64_t, std::vector<traccc::cell>> tracccCellsMap(
const std::map<Acts::GeometryIdentifier, cell_collection_t>& map)
{
std::map<std::uint64_t, std::vector<traccc::cell>> tracccCellMap;
for (const auto& [geometryID, cells] : map){
std::vector<traccc::cell> tracccCells;
for (const auto& cell : cells){
tracccCells.push_back(tracccCell(cell));
}
tracccCellMap.insert({geometryID.value(), std::move(tracccCells)});
}
return tracccCellMap;
}
std::map<std::uint64_t, std::vector<traccc::cell>> tracccCellsMap(
const std::map<Acts::GeometryIdentifier, std::vector<Cluster::Cell>>& map);

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,20 @@

// Acts include(s)
#include "Acts/Geometry/GeometryHierarchyMap.hpp"
#include "Acts/Utilities/BinUtility.hpp"

// Acts Examples include(s)
#include "ActsExamples/Digitization/DigitizationConfig.hpp"

// Traccc include(s)
#include "traccc/io/digitization_config.hpp"

// System include(s).
#include <cstdint>
#include <cstdlib>
#include <vector>
#include <map>

namespace ActsExamples::Traccc::Common::Conversion {

/// @brief Get the segmentation from a DigiComponentsConfig.
inline Acts::BinUtility getSegmentation(const DigiComponentsConfig& dcc){
return dcc.geometricDigiConfig.segmentation;
}

/// @brief Creates a traccc digitalization config from an Acts geometry hierarchy map
/// that contains the digitization configuration.
/// @param config the Acts geometry hierarchy map that contains the digitization configuration.
/// @return a traccc digitization config.
template <typename data_t>
inline traccc::digitization_config tracccConfig(
const Acts::GeometryHierarchyMap<data_t>& config){
using ElementType = std::pair<Acts::GeometryIdentifier, traccc::module_digitization_config>;
std::vector<ElementType> vec;
for (auto& e : config.getElements()){
vec.push_back({e.first, traccc::module_digitization_config{getSegmentation(e.second)}});
}
return traccc::digitization_config(vec);
}
traccc::digitization_config tracccConfig(
const Acts::GeometryHierarchyMap<DigiComponentsConfig>& config);

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,149 @@
#pragma once

// Plugin include(s)
#include "Acts/Plugins/Traccc/MeasurementConversion.hpp"
#include "Acts/Plugins/Traccc/Detail/AlgebraConversion.hpp"

// Acts include(s)
#include "Acts/Definitions/Algebra.hpp"
#include "Acts/Definitions/TrackParametrization.hpp"
#include "Acts/EventData/Measurement.hpp"
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/Geometry/GeometryIdentifier.hpp"

// Acts examples include(s)
// Acts Examples include(s)
#include "ActsExamples/EventData/IndexSourceLink.hpp"

// Detray include(s)
#include "detray/core/detector.hpp"
#include "detray/tracks/bound_track_parameters.hpp"

// Traccc include(s)
#include "traccc/definitions/qualifiers.hpp"
#include "traccc/definitions/track_parametrization.hpp"
#include "traccc/edm/measurement.hpp"
#include "traccc/edm/track_state.hpp"

// Detray include(s).
#include "detray/core/detector.hpp"

// System include(s).
// System include(s)
#include <memory>
#include <variant>
#include <cstdint>
#include <cstdlib>
#include <vector>

namespace ActsExamples::Traccc::Common::Conversion {

/// @brief Converts a traccc bound index to an Acts bound index.
/// @param tracccBoundIndex the traccc bound index.
/// @returns an Acts bound index.
Acts::BoundIndices boundIndex(
const traccc::bound_indices tracccBoundIndex);

/// @brief Creates an Acts measurement from a traccc measurement.
/// @tparam the dimension of the Acts measurement (subspace size).
/// @param m the traccc measurement.
/// @param sl the Acts source link to use for the Acts measurement.
/// @returns an Acts measurement with data copied from the traccc measurement
/// and with its source link set to the one provided to the function.
template <std::size_t dim>
inline Acts::Measurement<Acts::BoundIndices, dim> measurement(
const traccc::measurement& m, const Acts::SourceLink sl) {
auto params = Acts::TracccPlugin::detail::toActsVector<dim>(m.local);
std::array<Acts::BoundIndices, dim> indices;
for (unsigned int i = 0; i < dim; i++) {
indices[i] = boundIndex(traccc::bound_indices(m.subs.get_indices()[i]));
}
auto cov = Eigen::DiagonalMatrix<Acts::ActsScalar, static_cast<int>(dim)>(
Acts::TracccPlugin::detail::toActsVector<dim>(m.variance))
.toDenseMatrix();
return Acts::Measurement<Acts::BoundIndices, dim>(std::move(sl), indices,
params, cov);
}

/// @brief Creates an Acts bound variant measurement from a traccc measurement.
/// Using recursion, the functions determines the dimension of the traccc
/// measurement which is used for the Acts measurement that the bound variant
/// measurement holds. The dimension must lie between [0; max_dim].
/// @tparam max_dim the largest possible dimension of any measurement type in the variant (default = 4).
/// @param m the traccc measurement.
/// @param sl the Acts source link to use for the Acts measurement.
/// @returns an Acts bound variant measurement with data copied from the traccc measurement
/// and with its source link set to the one provided to the function.
template <std::size_t max_dim = 4UL>
inline Acts::BoundVariantMeasurement boundVariantMeasurement(
const traccc::measurement& m, const Acts::SourceLink sl) {
if constexpr (max_dim == 0UL) {
std::string errorMsg = "Invalid/mismatching measurement dimension: " +
std::to_string(m.meas_dim);
throw std::runtime_error(errorMsg.c_str());
} else {
if (m.meas_dim == max_dim) {
return measurement<max_dim>(m, sl);
}
return boundVariantMeasurement<max_dim - 1>(m, sl);
}
}

/// @brief Gets the the local position of the measurement.
/// @param measurement the Acts measurement.
/// @returns A two-dimensional vector containing the local position.
/// The first item in the vector is local position on axis 0 and
/// I.e., [local position (axis 0), local position (axis 1)].
/// @note if the dimension is less than 2 then the remaining values are set to 0.
template <std::size_t dim>
inline Acts::ActsVector<2> getLocal(
const Acts::Measurement<Acts::BoundIndices, dim>& measurement) {
traccc::scalar loc0 = 0;
traccc::scalar loc1 = 0;
if constexpr (dim > Acts::BoundIndices::eBoundLoc0) {
loc0 = measurement.parameters()(Acts::BoundIndices::eBoundLoc0);
}
if constexpr (dim > Acts::BoundIndices::eBoundLoc1) {
loc1 = measurement.parameters()(Acts::BoundIndices::eBoundLoc1);
}
return Acts::ActsVector<2>(loc0, loc1);
}

/// @brief Get the the local position of the measurement.
/// @param measurement the Acts bound variant measurement.
/// @return A two-dimensional vector containing the local position.
/// I.e., [local position (axis 0), local position (axis 1)].
/// @note if the dimension is less than 2 then the remaining values are set to 0.
inline Acts::ActsVector<2> getLocal(
const Acts::BoundVariantMeasurement& measurement) {
return std::visit([](auto& m) { return getLocal(m); }, measurement);
}

/// @brief Get the the variance of the measurement.
/// @param measurement the Acts measurement.
/// @return A two-dimensional vector containing the variance.
/// I.e., [variance (axis 0), variance (axis 1)].
/// @note if the dimension is less than 2 then the remaining values are set to 0.
template <std::size_t dim>
inline Acts::ActsVector<2> getVariance(
const Acts::Measurement<Acts::BoundIndices, dim>& measurement) {
traccc::scalar var0 = 0;
traccc::scalar var1 = 0;
if constexpr (dim >= Acts::BoundIndices::eBoundLoc0) {
var0 = measurement.covariance()(Acts::BoundIndices::eBoundLoc0,
Acts::BoundIndices::eBoundLoc0);
}
if constexpr (dim > Acts::BoundIndices::eBoundLoc1) {
var1 = measurement.covariance()(Acts::BoundIndices::eBoundLoc1,
Acts::BoundIndices::eBoundLoc1);
}
return Acts::ActsVector<2>(var0, var1);
}

/// @brief Get the the variance of the measurement.
/// @param measurement the Acts bound variant measurement.
/// @return A two-dimensional vector containing the variance.
/// I.e., [variance (axis 0), variance (axis 1)].
/// @note if the dimension is less than 2 then the remaining values are set to 0.
inline Acts::ActsVector<2> getVariance(
const Acts::BoundVariantMeasurement& measurement) {
return std::visit([](auto& m) { return getVariance(m); }, measurement);
}

/// @brief Converts traccc measurements to acts measurements.
/// @param detector The detray detector,
/// @param measurements The traccc measurements,
Expand All @@ -43,7 +165,7 @@ inline auto createActsMeasurements(const detector_t& detector, const std::vector
Acts::GeometryIdentifier moduleGeoId(detector.surface(m.surface_link).source);
Index measurementIdx = measurementContainer.size();
IndexSourceLink idxSourceLink{moduleGeoId, measurementIdx};
measurementContainer.push_back(Acts::TracccPlugin::boundVariantMeasurement(m, Acts::SourceLink{idxSourceLink}));
measurementContainer.push_back(boundVariantMeasurement(m, Acts::SourceLink{idxSourceLink}));
}
return measurementContainer;
}
Expand Down
Loading

0 comments on commit e647915

Please sign in to comment.