Skip to content

Commit

Permalink
Add a LinkNavigator utility (#646)
Browse files Browse the repository at this point in the history
* Add AssociationNavigator utility class and tests

* Rename Association to Link

* Default initialize the internal maps

* Make navigator work with links between same types

* Remove more mentions of assocations

* Avoid depending on an order in the linked objects

* Add documentation

* Make dependent types usable in c++17

* Documentation fixes as suggested

* Make directional overloads take a tag argument

Now all overloads have the same name, and selection is done via a tag
argument.

* Rename overload selection tags and refine docstring wording

* Update docs to new naming scheme

* Make sure documentation and code are consistent

Co-authored-by: Andre Sailer <[email protected]>

---------

Co-authored-by: Andre Sailer <[email protected]>
  • Loading branch information
tmadlener and andresailer authored Dec 5, 2024
1 parent 39587ee commit c97bdf1
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 1 deletion.
38 changes: 38 additions & 0 deletions doc/links.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,44 @@ and have that compiled into the library. This is necessary if you want to use
the python bindings, since they rely on dynamically loading the datamodel
libraries.

## The `LinkNavigator` utility

`podio::LinkCollection`s store each link separately even if a given object is
present in several links. Additionally, they don't offer any really easy way to
look up objects that are linked (apart from manually looping and comparing
elements). To alleviate these issues, we provide the `podio::LinkNavigator`
utility class that facilitates navigating links and lookups. It can be
constructed from any `podio::LinkCollection` and can then be used to retrieve
linked objects. E.g.

```cpp
const auto& recoMcLinks = event.get<edm4hep::RecoMCParticleLinkCollection>("RecoMCLinks");
const auto linkNavigator = podio::LinkNavigator(recoMcLinks);
// For podio::LinkCollections with disparate types just use getLinked
const auto linkedRecs = linkNavigator.getLinked(mcParticle);
```

If you want to be explicit about the lookup direction, e.g. in case you have a
link that has the same `From` and `To` type, you can use the overloads that take
a second *tag argument*:
```cpp
const auto linkedMCs = linkNavigator.getLinked(recoParticle, podio::ReturnTo);
```

The return type of all methods is a `std::vector<WeightedObject>`, where the
`WeightedObject` is a simple template class that wraps the object and its
weight. It supports structured bindings, so you can e.g. do the following

```cpp
for (const auto& [reco, weight] : linkedRecs) {
// do something with the reco particle and its weight
}
```

Alternatively, you can access the object via the `o` member and the weight via
the `weight` member.

## Implementation details

In order to give a slightly easier entry to the details of the implementation
Expand Down
171 changes: 171 additions & 0 deletions include/podio/LinkNavigator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
#ifndef PODIO_LINKNAVIGATOR_H
#define PODIO_LINKNAVIGATOR_H

#include <map>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

namespace podio {

namespace detail::links {
/// A small struct that simply bundles an object and its weight for a more
/// convenient return value for the LinkNavigator
///
/// @note In most uses the names of the members should not really matter as it
/// is possible to us this via structured bindings
template <typename T>
struct WeightedObject {
WeightedObject(T obj, float w) : o(obj), weight(w) {
}
T o; ///< The object
float weight; ///< The weight in the link

bool operator==(const WeightedObject<T>& other) const {
return other.o == o && other.weight == weight;
}
};

/// Simple struct tag for overload selection in LinkNavigator below
struct ReturnFromTag {};
/// Simple struct tag for overload selection in LinkNavigator below
struct ReturnToTag {};
} // namespace detail::links

/// Tag variable to select the lookup of *From* objects have links with a *To*
/// object in podio::LinkNavigator::getLinked
static constexpr detail::links::ReturnFromTag ReturnFrom;
/// Tag variable to select the lookup of *To* objects that have links with a
/// *From* object in podio::LinkNavigator::getLinked
static constexpr detail::links::ReturnToTag ReturnTo;

/// A helper class to more easily handle one-to-many links.
///
/// Internally simply populates two maps in its constructor and then queries
/// them to retrieve objects that are linked with another.
///
/// @note There are no guarantees on the order of the objects in these maps.
/// Hence, there are also no guarantees on the order of the returned objects,
/// even if there inherintly is an order to them in the underlying links
/// collection.
template <typename LinkCollT>
class LinkNavigator {
using FromT = typename LinkCollT::from_type;
using ToT = typename LinkCollT::to_type;

template <typename T>
using WeightedObject = detail::links::WeightedObject<T>;

public:
/// Construct a navigator from an link collection
LinkNavigator(const LinkCollT& links);

/// We do only construct from a collection
LinkNavigator() = delete;
LinkNavigator(const LinkNavigator&) = default;
LinkNavigator& operator=(const LinkNavigator&) = default;
LinkNavigator(LinkNavigator&&) = default;
LinkNavigator& operator=(LinkNavigator&&) = default;
~LinkNavigator() = default;

/// Get all the *From* objects and weights that have links with the passed
/// object
///
/// You will get this overload if you pass the podio::ReturnFrom tag as second
/// argument
///
/// @note This overload works always, even if the LinkCollection that was used
/// to construct this instance of the LinkNavigator has the same From and To
/// types.
///
/// @param object The object that is labeled *To* in the link
/// @param . tag variable for selecting this overload
///
/// @returns A vector of all objects and their weights that have links with
/// the passed object
std::vector<WeightedObject<FromT>> getLinked(const ToT& object, podio::detail::links::ReturnFromTag) const {
const auto& [begin, end] = m_to2from.equal_range(object);
std::vector<WeightedObject<FromT>> result;
result.reserve(std::distance(begin, end));

for (auto it = begin; it != end; ++it) {
result.emplace_back(it->second);
}
return result;
}

/// Get all the *From* objects and weights that have links with the passed
/// object
///
/// @note This overload will automatically do the right thing (TM) in case the
/// LinkCollection that has been passed to construct this LinkNavigator has
/// different From and To types.
///
/// @param object The object that is labeled *To* in the link
///
/// @returns A vector of all objects and their weights that have links with
/// the passed object
template <typename ToU = ToT>
std::enable_if_t<!std::is_same_v<FromT, ToU>, std::vector<WeightedObject<FromT>>> getLinked(const ToT& object) const {
return getLinked(object, podio::ReturnFrom);
}

/// Get all the *To* objects and weights that have links with the passed
/// object
///
/// You will get this overload if you pass the podio::ReturnTo tag as second
/// argument
///
/// @note This overload works always, even if the LinkCollection that was used
/// to construct this instance of the LinkNavigator has the same From and To
/// types.
///
/// @param object The object that is labeled *From* in the link
/// @param . tag variable for selecting this overload
///
/// @returns A vector of all objects and their weights that have links with
/// the passed object
std::vector<WeightedObject<ToT>> getLinked(const FromT& object, podio::detail::links::ReturnToTag) const {
const auto& [begin, end] = m_from2to.equal_range(object);
std::vector<WeightedObject<ToT>> result;
result.reserve(std::distance(begin, end));

for (auto it = begin; it != end; ++it) {
result.emplace_back(it->second);
}
return result;
}

/// Get all the *To* objects and weights that have links with the passed
/// object
///
/// @note This overload will automatically do the right thing (TM) in case the
/// LinkCollection that has been passed to construct this LinkNavigator has
/// different From and To types.
///
/// @param object The object that is labeled *From* in the link
///
/// @returns A vector of all objects and their weights that have links with
/// the passed object
template <typename FromU = FromT>
std::enable_if_t<!std::is_same_v<FromU, ToT>, std::vector<WeightedObject<ToT>>> getLinked(const FromT& object) const {
return getLinked(object, podio::ReturnTo);
}

private:
std::multimap<FromT, WeightedObject<ToT>> m_from2to{}; ///< Map the from to the to objects
std::multimap<ToT, WeightedObject<FromT>> m_to2from{}; ///< Map the to to the from objects
};

template <typename LinkCollT>
LinkNavigator<LinkCollT>::LinkNavigator(const LinkCollT& links) {
for (const auto& [from, to, weight] : links) {
m_from2to.emplace(std::piecewise_construct, std::forward_as_tuple(from), std::forward_as_tuple(to, weight));
m_to2from.emplace(std::piecewise_construct, std::forward_as_tuple(to), std::forward_as_tuple(from, weight));
}
}

} // namespace podio

#endif // PODIO_LINKNAVIGATOR_H
2 changes: 2 additions & 0 deletions include/podio/detail/LinkCollectionImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class LinkCollection : public podio::CollectionBase {
using CollectionDataT = podio::LinkCollectionData<FromT, ToT>;

public:
using from_type = FromT;
using to_type = ToT;
using value_type = Link<FromT, ToT>;
using mutable_type = MutableLink<FromT, ToT>;
using const_iterator = LinkCollectionIterator<FromT, ToT>;
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
if(CMAKE_CXX_STANDARD GREATER_EQUAL 20)
set(CATCH2_MIN_VERSION 3.4)
else()
set(CATCH2_MIN_VERSION 3.1)
set(CATCH2_MIN_VERSION 3.3)
endif()
if(USE_EXTERNAL_CATCH2)
if (USE_EXTERNAL_CATCH2 STREQUAL AUTO)
Expand Down
87 changes: 87 additions & 0 deletions tests/unittests/links.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "catch2/catch_test_macros.hpp"
#include "catch2/matchers/catch_matchers_vector.hpp"

#include "podio/LinkCollection.h"
#include "podio/LinkNavigator.h"

#include "datamodel/ExampleClusterCollection.h"
#include "datamodel/ExampleHitCollection.h"
Expand Down Expand Up @@ -473,3 +475,88 @@ TEST_CASE("Link JSON conversion", "[links][json]") {
}

#endif

TEST_CASE("LinkNavigator basics", "[links]") {
TestLColl coll{};
std::vector<ExampleHit> hits(11);
std::vector<ExampleCluster> clusters(3);

for (size_t i = 0; i < 10; ++i) {
auto a = coll.create();
a.set(hits[i]);
a.set(clusters[i % 3]);
a.setWeight(i * 0.1f);
}

auto a = coll.create();
a.set(hits[10]);

podio::LinkNavigator nav{coll};

for (size_t i = 0; i < 10; ++i) {
const auto& hit = hits[i];
const auto linkedClusters = nav.getLinked(hit);
REQUIRE(linkedClusters.size() == 1);
const auto& [cluster, weight] = linkedClusters[0];
REQUIRE(cluster == clusters[i % 3]);
REQUIRE(weight == i * 0.1f);
}

using Catch::Matchers::UnorderedEquals;
using podio::detail::links::WeightedObject;
using WeightedHits = std::vector<WeightedObject<ExampleHit>>;

auto linkedHits = nav.getLinked(clusters[0]);
REQUIRE_THAT(linkedHits,
UnorderedEquals(WeightedHits{WeightedObject{hits[0], 0.f}, WeightedObject{hits[3], 3 * 0.1f},
WeightedObject{hits[6], 6 * 0.1f}, WeightedObject{hits[9], 9 * 0.1f}}));

linkedHits = nav.getLinked(clusters[1]);
REQUIRE_THAT(linkedHits,
UnorderedEquals(WeightedHits{WeightedObject{hits[1], 0.1f}, WeightedObject{hits[4], 0.4f},
WeightedObject{hits[7], 0.7f}}));

const auto [noCluster, noWeight] = nav.getLinked(hits[10])[0];
REQUIRE_FALSE(noCluster.isAvailable());
}

TEST_CASE("LinkNavigator same types", "[links]") {
std::vector<ExampleCluster> clusters(3);
auto linkColl = podio::LinkCollection<ExampleCluster, ExampleCluster>{};
auto link = linkColl.create();
link.setFrom(clusters[0]);
link.setTo(clusters[1]);
link.setWeight(0.5f);

link = linkColl.create();
link.setFrom(clusters[0]);
link.setTo(clusters[2]);
link.setWeight(0.25f);

link = linkColl.create();
link.setFrom(clusters[1]);
link.setTo(clusters[2]);
link.setWeight(0.66f);

auto navigator = podio::LinkNavigator{linkColl};
auto linkedClusters = navigator.getLinked(clusters[1], podio::ReturnTo);
REQUIRE(linkedClusters.size() == 1);
REQUIRE(linkedClusters[0].o == clusters[2]);
REQUIRE(linkedClusters[0].weight == 0.66f);

linkedClusters = navigator.getLinked(clusters[1], podio::ReturnFrom);
REQUIRE(linkedClusters.size() == 1);
REQUIRE(linkedClusters[0].o == clusters[0]);
REQUIRE(linkedClusters[0].weight == 0.5f);

using Catch::Matchers::UnorderedEquals;
using podio::detail::links::WeightedObject;
using WeightedObjVec = std::vector<WeightedObject<ExampleCluster>>;
linkedClusters = navigator.getLinked(clusters[0], podio::ReturnTo);
REQUIRE_THAT(linkedClusters,
UnorderedEquals(WeightedObjVec{WeightedObject(clusters[1], 0.5f), WeightedObject{clusters[2], 0.25f}}));

linkedClusters = navigator.getLinked(clusters[2], podio::ReturnFrom);
REQUIRE_THAT(linkedClusters,
UnorderedEquals(WeightedObjVec{WeightedObject{clusters[0], 0.25f}, WeightedObject{clusters[1], 0.66f}}));
}

0 comments on commit c97bdf1

Please sign in to comment.