Skip to content

Commit

Permalink
Make it possible to use interface types in STL containers relying on …
Browse files Browse the repository at this point in the history
…comparisons (#552)

* Add missing final for override

* Switch to defaultdict

* Make it possible to store interface types in maps and sets
  • Loading branch information
tmadlener authored Feb 6, 2024
1 parent a25a6f6 commit 3d381ab
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 5 deletions.
33 changes: 29 additions & 4 deletions python/podio_gen/cpp_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,14 @@ def __init__(
self.root_schema_component_names = set()
self.root_schema_datatype_names = set()
self.root_schema_iorules = set()
# a map of datatypes that are used in interfaces populated by pre_process
self.types_in_interfaces = {}

def pre_process(self):
"""The necessary specific pre-processing for cpp code generation"""
self._pre_process_schema_evolution()
self.types_in_interfaces = self._invert_interfaces()

return {}

def post_process(self, _):
Expand Down Expand Up @@ -120,6 +124,7 @@ def do_process_component(self, name, component):
def do_process_datatype(self, name, datatype):
"""Do the cpp specific processing of a datatype"""
datatype["includes_data"] = self._get_member_includes(datatype["Members"])
datatype["using_interface_types"] = self.types_in_interfaces.get(name, [])
self._preprocess_for_class(datatype)
self._preprocess_for_obj(datatype)
self._preprocess_for_collection(datatype)
Expand Down Expand Up @@ -201,7 +206,7 @@ def print_report(self):
def _preprocess_for_class(self, datatype):
"""Do the preprocessing that is necessary for the classes and Mutable classes"""
includes = set(datatype["includes_data"])
fwd_declarations = {}
fwd_declarations = defaultdict(list)
includes_cc = set()

for member in datatype["Members"]:
Expand All @@ -212,10 +217,8 @@ def _preprocess_for_class(self, datatype):
if self._is_interface(relation.full_type):
relation.interface_types = self.datamodel.interfaces[relation.full_type]["Types"]
if self._needs_include(relation.full_type):
if relation.namespace not in fwd_declarations:
fwd_declarations[relation.namespace] = []
fwd_declarations[relation.namespace].append(relation.bare_type)
fwd_declarations[relation.namespace].append("Mutable" + relation.bare_type)
fwd_declarations[relation.namespace].append(f"Mutable{relation.bare_type}")
includes_cc.add(self._build_include(relation))

if datatype["VectorMembers"] or datatype["OneToManyRelations"]:
Expand Down Expand Up @@ -246,6 +249,13 @@ def _preprocess_for_class(self, datatype):
except KeyError:
pass

# Make sure that all using interface types are properly forward declared
# to make it possible to declare them as friends so that they can access
# internals more easily
for interface in datatype["using_interface_types"]:
if_type = DataType(interface)
fwd_declarations[if_type.namespace].append(if_type.bare_type)

datatype["includes"] = self._sort_includes(includes)
datatype["includes_cc"] = self._sort_includes(includes_cc)
datatype["forward_declarations"] = fwd_declarations
Expand Down Expand Up @@ -381,6 +391,21 @@ def _pre_process_schema_evolution(self):
# add whatever is relevant to our ROOT schema evolution
self.root_schema_dict.setdefault(item.klassname, []).append(item)

def _invert_interfaces(self):
"""'Invert' the interfaces to have a mapping of types and their usage in
interfaces.
This is necessary to declare the interface types as friends of the
classes they wrap in order to more easily access some internals.
"""
types_in_interfaces = defaultdict(list)
for name, interface in self.datamodel.interfaces.items():
print(f"preprocessing interface {name}")
for if_type in interface["Types"]:
types_in_interfaces[if_type.full_type].append(name)

return types_in_interfaces

def _prepare_iorules(self):
"""Prepare the IORules to be put in the Reflex dictionary"""
for type_name, schema_changes in self.root_schema_dict.items():
Expand Down
11 changes: 10 additions & 1 deletion python/templates/Interface.h.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class {{ class.bare_type }} {
{{ macros.member_getters_concept(Members, use_get_syntax) }}
virtual const std::type_info& typeInfo() const = 0;
virtual bool equal(const Concept* rhs) const = 0;
virtual const void* objAddress() const = 0;
};

template<typename ValueT>
Expand All @@ -65,7 +66,7 @@ class {{ class.bare_type }} {

void unlink() final { m_value.unlink(); }
bool isAvailable() const final { return m_value.isAvailable(); }
podio::ObjectID getObjectID() const { return m_value.getObjectID(); }
podio::ObjectID getObjectID() const final { return m_value.getObjectID(); }

const std::type_info& typeInfo() const final { return typeid(ValueT); }

Expand All @@ -76,6 +77,10 @@ class {{ class.bare_type }} {
return false;
}

const void* objAddress() const final {
return m_value.m_obj.get();
}

{{ macros.member_getters_model(Members, use_get_syntax) }}

ValueT m_value{};
Expand Down Expand Up @@ -144,6 +149,10 @@ public:
return !(lhs == rhs);
}

friend bool operator<(const {{ class.bare_type }}& lhs, const {{ class.bare_type }}& rhs) {
return lhs.m_self->objAddress() < rhs.m_self->objAddress();
}

{{ macros.member_getters(Members, use_get_syntax) }}

friend std::ostream& operator<<(std::ostream& os, const {{ class.bare_type }}& value) {
Expand Down
3 changes: 3 additions & 0 deletions python/templates/Object.h.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class {{ class.bare_type }} {
friend class {{ class.bare_type }}Collection;
friend class {{ class.full_type }}CollectionData;
friend class {{ class.bare_type }}CollectionIterator;
{% for interface in using_interface_types %}
friend class {{ interface }};
{% endfor %}

public:
using mutable_type = Mutable{{ class.bare_type }};
Expand Down
21 changes: 21 additions & 0 deletions tests/unittests/interface_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "datamodel/ExampleClusterCollection.h"
#include "datamodel/ExampleHitCollection.h"
#include "datamodel/TypeWithEnergy.h"

#include <map>
#include <stdexcept>

TEST_CASE("InterfaceTypes basic functionality", "[interface-types][basics]") {
Expand Down Expand Up @@ -45,6 +47,25 @@ TEST_CASE("InterfaceTypes basic functionality", "[interface-types][basics]") {
REQUIRE(wrapper1.id() == podio::ObjectID{0, 42});
}

TEST_CASE("InterfaceTypes STL usage", "[interface-types][basics]") {
// Make sure that interface types can be used with STL map and set
std::map<TypeWithEnergy, int> counterMap{};

auto empty = TypeWithEnergy::makeEmpty();
counterMap[empty]++;

ExampleHit hit{};
auto wrapper = TypeWithEnergy{hit};
counterMap[wrapper]++;

// No way this implicit conversion could ever lead to a subtle bug ;)
counterMap[hit]++;

REQUIRE(counterMap[empty] == 1);
REQUIRE(counterMap[hit] == 2);
REQUIRE(counterMap[wrapper] == 2);
}

TEST_CASE("InterfaceType from immutable", "[interface-types][basics]") {
using WrapperT = TypeWithEnergy;

Expand Down

0 comments on commit 3d381ab

Please sign in to comment.