Add optional collection selection to podio::DataSource

Adds a trailing collsToRead argument (defaulting to an empty list) to both
podio::DataSource constructors and to both podio::CreateDataFrame overloads.
The list is handed through DataSource::SetupInput to the readFrame call that
loads the first event frame. The RDataSource test now compares the column
names of the data frame against the collections reported by a plain reader,
and checks that requesting only "hits" yields "hits" as the first column.

NOTE(review): in this copy of the patch, text inside angle brackets (template
arguments such as the element type of std::vector, and the header names of
some #include lines) is missing. It is left as found because not all of it
can be recovered from the patch itself; restore it before applying.

diff --git a/include/podio/DataSource.h b/include/podio/DataSource.h
index ba00fe4e4..d13f78487 100644
--- a/include/podio/DataSource.h
+++ b/include/podio/DataSource.h
@@ -23,12 +23,25 @@ class DataSource : public ROOT::RDF::RDataSource {
   ///
   /// @brief Construct the podio::DataSource from the provided file.
   ///
-  explicit DataSource(const std::string& filePath, int nEvents = -1);
+  /// @param filePath Path to the file that should be read
+  /// @param nEvents Number of events to process (optional, defaults to -1 for
+  ///                all events)
+  /// @param collsToRead The collections that should be made available (optional,
+  ///                    defaults to empty vector for all collections)
+  ///
+  explicit DataSource(const std::string& filePath, int nEvents = -1, const std::vector& collsToRead = {});
 
   ///
   /// @brief Construct the podio::DataSource from the provided file list.
   ///
-  explicit DataSource(const std::vector& filePathList, int nEvents = -1);
+  /// @param filePathList Paths to the files that should be read
+  /// @param nEvents Number of events to process (optional, defaults to -1 for
+  ///                all events)
+  /// @param collsToRead The collections that should be made available (optional,
+  ///                    defaults to empty vector for all collections)
+  ///
+  explicit DataSource(const std::vector& filePathList, int nEvents = -1,
+                      const std::vector& collsToRead = {});
 
   ///
   /// @brief Inform the podio::DataSource of the desired level of parallelism.
@@ -139,7 +152,8 @@ class DataSource : public ROOT::RDF::RDataSource {
   ///
   /// @param[in] nEvents Number of events.
+  /// @param[in] collsToRead Collection names passed on to the reader's readFrame call.
   ///
-  void SetupInput(int nEvents);
+  void SetupInput(int nEvents, const std::vector& collsToRead);
 };
 
 ///
@@ -147,17 +161,24 @@ class DataSource : public ROOT::RDF::RDataSource {
 ///
 /// @param[in] filePathList List of file paths from which the RDataFrame
 ///                         will be created.
+/// @param[in] collsToRead List of collection names that should be made
+///                        available
+///
 /// @return RDataFrame created from input file list.
 ///
-ROOT::RDataFrame CreateDataFrame(const std::vector& filePathList);
+ROOT::RDataFrame CreateDataFrame(const std::vector& filePathList,
+                                 const std::vector& collsToRead = {});
 
 ///
 /// @brief Create RDataFrame from a Podio file.
 ///
 /// @param[in] filePath File path from which the RDataFrame will be created.
+/// @param[in] collsToRead List of collection names that should be made
+///                        available
+///
 /// @return RDataFrame created from input file list.
 ///
-ROOT::RDataFrame CreateDataFrame(const std::string& filePath);
+ROOT::RDataFrame CreateDataFrame(const std::string& filePath, const std::vector& collsToRead = {});
 } // namespace podio
 
 #endif /* PODIO_DATASOURCE_H */
diff --git a/src/DataSource.cc b/src/DataSource.cc
index 4b8fac0af..05334a2bd 100644
--- a/src/DataSource.cc
+++ b/src/DataSource.cc
@@ -13,17 +13,19 @@
 #include
 
 namespace podio {
-DataSource::DataSource(const std::string& filePath, int nEvents) : m_nSlots{1} {
+DataSource::DataSource(const std::string& filePath, int nEvents, const std::vector& collsToRead) :
+    m_nSlots{1} {
   m_filePathList.emplace_back(filePath);
-  SetupInput(nEvents);
+  SetupInput(nEvents, collsToRead);
 }
 
-DataSource::DataSource(const std::vector& filePathList, int nEvents) :
+DataSource::DataSource(const std::vector& filePathList, int nEvents,
+                       const std::vector& collsToRead) :
     m_nSlots{1}, m_filePathList{filePathList} {
-  SetupInput(nEvents);
+  SetupInput(nEvents, collsToRead);
 }
 
-void DataSource::SetupInput(int nEvents) {
+void DataSource::SetupInput(int nEvents, const std::vector& collsToRead) {
   if (m_filePathList.empty()) {
     throw std::runtime_error("podio::DataSource: No input files provided!");
   }
@@ -36,7 +38,7 @@ void DataSource::SetupInput(int nEvents) {
   unsigned int nEventsInFiles = 0;
   auto podioReader = podio::makeReader(m_filePathList);
   nEventsInFiles = podioReader.getEntries(podio::Category::Event);
-  frame = podioReader.readFrame(podio::Category::Event, 0);
+  frame = podioReader.readFrame(podio::Category::Event, 0, collsToRead);
 
   // Determine over how many events to run
   if (nEventsInFiles <= 0) {
@@ -173,14 +175,15 @@ std::string DataSource::GetTypeName(std::string_view columnName) const {
   return m_columnTypes.at(typeIndex);
 }
 
-ROOT::RDataFrame CreateDataFrame(const std::vector& filePathList) {
-  ROOT::RDataFrame rdf(std::make_unique(filePathList));
+ROOT::RDataFrame CreateDataFrame(const std::vector& filePathList,
+                                 const std::vector& collsToRead) {
+  ROOT::RDataFrame rdf(std::make_unique(filePathList, -1, collsToRead));
 
   return rdf;
 }
 
-ROOT::RDataFrame CreateDataFrame(const std::string& filePath) {
-  ROOT::RDataFrame rdf(std::make_unique(filePath));
+ROOT::RDataFrame CreateDataFrame(const std::string& filePath, const std::vector& collsToRead) {
+  ROOT::RDataFrame rdf(std::make_unique(filePath, -1, collsToRead));
 
   return rdf;
 }
diff --git a/tests/root_io/read_with_rdatasource_root.cpp b/tests/root_io/read_with_rdatasource_root.cpp
index 29bad4319..7ef70a4ed 100644
--- a/tests/root_io/read_with_rdatasource_root.cpp
+++ b/tests/root_io/read_with_rdatasource_root.cpp
@@ -1,6 +1,8 @@
 #include "datamodel/ExampleClusterCollection.h"
 #include "podio/DataSource.h"
+#include "podio/Reader.h"
 
+#include
 #include
 #include
 
@@ -28,8 +30,40 @@ int main(int argc, const char* argv[]) {
   dframe.Describe().Print();
   std::cout << std::endl;
 
+  const auto expectedCollNames = [&inputFile]() {
+    auto reader = podio::makeReader(inputFile);
+    auto cols = reader.readNextEvent().getAvailableCollections();
+    std::ranges::sort(cols);
+    return cols;
+  }();
+  const auto allColNames = [&dframe]() {
+    auto cols = dframe.GetColumnNames();
+    std::ranges::sort(cols);
+    return cols;
+  }();
+
+  if (!std::ranges::equal(expectedCollNames, allColNames)) {
+    std::cerr << "Column names are not as expected\n expected: [";
+    for (const auto& name : expectedCollNames) {
+      std::cerr << name << " ";
+    }
+    std::cerr << "]\n actual: [";
+    for (const auto& name : allColNames) {
+      std::cerr << name << " ";
+    }
+    std::cerr << "]" << std::endl;
+
+    return EXIT_FAILURE;
+  }
+
   auto cluterEnergy = dframe.Define("cluster_energy", getEnergy, {"clusters"}).Histo1D("cluster_energy");
   cluterEnergy->Print();
 
+  dframe = podio::CreateDataFrame(inputFile, {"hits"});
+  if (const auto colNames = dframe.GetColumnNames(); colNames.empty() || colNames[0] != "hits") {
+    std::cerr << "Limiting to only one collection didn't work as expected" << std::endl;
+    return EXIT_FAILURE;
+  }
+
   return EXIT_SUCCESS;
 }