diff --git a/python/podio/base_writer.py b/python/podio/base_writer.py new file mode 100644 index 000000000..6f2b5777a --- /dev/null +++ b/python/podio/base_writer.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +"""Python module for defining the basic writer interface that is used by the +backend specific bindings""" + + +class BaseWriterMixin: + """Mixin class that defines the base interface of the writers. + + The backend specific writers inherit from here and have to initialize the + following members: + - _writer: The actual writer that is able to write frames + """ + + def write_frame(self, frame, category, collections=None): + """Write the given frame under the passed category, optionally limiting the + collections that are written. + + Args: + frame (podio.frame.Frame): The Frame to write + category (str): The category name + collections (optional, default=None): The subset of collections to + write. If None, all collections are written + """ + # pylint: disable-next=protected-access + self._writer.writeFrame(frame._frame, category, collections or frame.collections) diff --git a/python/podio/frame.py b/python/podio/frame.py index 4822b00df..1a69d7a46 100644 --- a/python/podio/frame.py +++ b/python/podio/frame.py @@ -48,16 +48,38 @@ def _determine_cpp_type(idx_and_type): SUPPORTED_PARAMETER_TYPES = _determine_supported_parameter_types() -def _get_cpp_vector_types(type_str): - """Get the possible std::vector from the passed py_type string.""" - # Gather a list of all types that match the type_str (c++ or python) +def _get_cpp_types(type_str): + """Get all possible c++ types from the passed py_type string.""" types = list(filter(lambda t: type_str in t, SUPPORTED_PARAMETER_TYPES)) if not types: raise ValueError(f'{type_str} cannot be mapped to a valid parameter type') + return types + + +def _get_cpp_vector_types(type_str): + """Get the possible std::vector from the passed py_type string.""" + # Gather a list of all types that match the type_str (c++ or python) + types = _get_cpp_types(type_str) return [f'std::vector<{t}>' for t in map(lambda x: x[0], types)] +def _is_collection_base(thing): + """Check whether the passed thing is a podio::CollectionBase + + Args: + thing (any): any object + + Returns: + bool: True if thing is a base of podio::CollectionBase, False otherwise + """ + # Make sure to only instantiate the template with things that cppyy + # understands + if "cppyy" in repr(thing): + return cppyy.gbl.std.is_base_of[cppyy.gbl.podio.CollectionBase, type(thing)].value + return False + + class Frame: """Frame class that serves as a container of collection and meta data.""" @@ -78,17 +100,16 @@ def __init__(self, data=None): else: self._frame = podio.Frame() - self._collections = tuple(str(s) for s in self._frame.getAvailableCollections()) - self._param_key_types = self._init_param_keys() + self._param_key_types = self._get_param_keys_types() @property def collections(self): - """Get the available collection (names) from this Frame. + """Get the currently available collection (names) from this Frame. Returns: tuple(str): The names of the available collections from this Frame. """ - return self._collections + return tuple(str(s) for s in self._frame.getAvailableCollections()) def get(self, name): """Get a collection from the Frame by name. @@ -107,9 +128,32 @@ def get(self, name): raise KeyError(f"Collection '{name}' is not available") return collection + def put(self, collection, name): + """Put the collection into the frame + + The passed collectoin is "moved" into the Frame, i.e. it cannot be used any + longer after a call to this function. This also means that only objects that + were in the collection at the time of calling this function will be + available afterwards. + + Args: + collection (podio.CollectionBase): The collection to put into the Frame + name (str): The name of the collection + + Returns: + podio.CollectionBase: The reference to the collection that has been put + into the Frame. NOTE: That mutating this collection is not allowed. + + Raises: + ValueError: If collection is not actually a podio.CollectionBase + """ + if not _is_collection_base(collection): + raise ValueError("Can only put podio collections into a Frame") + return self._frame.put(cppyy.gbl.std.move(collection), name) + @property def parameters(self): - """Get the available parameter names from this Frame. + """Get the currently available parameter names from this Frame. Returns: tuple (str): The names of the available parameters from this Frame. @@ -163,6 +207,58 @@ def _get_param_value(par_type, name): return _get_param_value(vec_types[0], name) + def put_parameter(self, key, value, as_type=None): + """Put a parameter into the Frame. + + Puts a parameter into the Frame after doing some (incomplete) type checks. + If a list is passed the parameter type is determined from looking at the + first element of the list only. Additionally, since python doesn't + differentiate between floats and doubles, floats will always be stored as + doubles by default, use the as_type argument to change this if necessary. + + Args: + key (str): The name of the parameter + value (int, float, str or list of these): The parameter value + as_type (str, optional): Explicitly specify the type that should be used + to put the parameter into the Frame. Python types (e.g. "str") will + be converted to c++ types. This will override any automatic type + deduction that happens otherwise. Note that this will be taken at + pretty much face-value and there are only limited checks for this. + + Raises: + ValueError: If a non-supported parameter type is passed + """ + # For lists we determine the c++ vector type and use that to call the + # correct template overload explicitly + if isinstance(value, (list, tuple)): + type_name = as_type or type(value[0]).__name__ + vec_types = _get_cpp_vector_types(type_name) + if len(vec_types) == 0: + raise ValueError(f"Cannot put a parameter of type {type_name} into a Frame") + + par_type = vec_types[0] + if isinstance(value[0], float): + # Always store floats as doubles from the python side + par_type = par_type.replace("float", "double") + + self._frame.putParameter[par_type](key, value) + else: + if as_type is not None: + cpp_types = _get_cpp_types(as_type) + if len(cpp_types) == 0: + raise ValueError(f"Cannot put a parameter of type {as_type} into a Frame") + self._frame.putParameter[cpp_types[0]](key, value) + + # If we have a single integer, a std::string overload kicks in with higher + # priority than the template for some reason. So we explicitly select the + # correct template here + elif isinstance(value, int): + self._frame.putParameter["int"](key, value) + else: + self._frame.putParameter(key, value) + + self._param_key_types = self._get_param_keys_types() # refresh the cache + def get_parameters(self): """Get the complete podio::GenericParameters object stored in this Frame. @@ -200,7 +296,7 @@ def get_param_info(self, name): return par_infos - def _init_param_keys(self): + def _get_param_keys_types(self): """Initialize the param keys dict for easier lookup of the available parameters. Returns: diff --git a/python/podio/root_io.py b/python/podio/root_io.py index a5f25950e..9623ee24d 100644 --- a/python/podio/root_io.py +++ b/python/podio/root_io.py @@ -6,8 +6,7 @@ from ROOT import podio # noqa: E402 # pylint: disable=wrong-import-position from podio.base_reader import BaseReaderMixin # pylint: disable=wrong-import-position - -Writer = podio.ROOTFrameWriter +from podio.base_writer import BaseWriterMixin # pylint: disable=wrong-import-position class Reader(BaseReaderMixin): @@ -49,3 +48,14 @@ def __init__(self, filenames): self._is_legacy = True super().__init__() + + +class Writer(BaseWriterMixin): + """Writer class for writing podio root files""" + def __init__(self, filename): + """Create a writer for writing files + + Args: + filename (str): The name of the output file + """ + self._writer = podio.ROOTFrameWriter(filename) diff --git a/python/podio/sio_io.py b/python/podio/sio_io.py index 01f9d577f..30257a860 100644 --- a/python/podio/sio_io.py +++ b/python/podio/sio_io.py @@ -9,8 +9,7 @@ from ROOT import podio # noqa: 402 # pylint: disable=wrong-import-position from podio.base_reader import BaseReaderMixin # pylint: disable=wrong-import-position - -Writer = podio.SIOFrameWriter +from podio.base_writer import BaseWriterMixin # pylint: disable=wrong-import-position class Reader(BaseReaderMixin): @@ -46,3 +45,14 @@ def __init__(self, filename): self._is_legacy = True super().__init__() + + +class Writer(BaseWriterMixin): + """Writer class for writing podio root files""" + def __init__(self, filename): + """Create a writer for writing files + + Args: + filename (str): The name of the output file + """ + self._writer = podio.SIOFrameWriter(filename) diff --git a/python/podio/test_Frame.py b/python/podio/test_Frame.py index f8ec1ad96..c09143056 100644 --- a/python/podio/test_Frame.py +++ b/python/podio/test_Frame.py @@ -7,6 +7,8 @@ # using root_io as that should always be present regardless of which backends are built from podio.root_io import Reader +from podio.test_utils import ExampleHitCollection + # The expected collections in each frame EXPECTED_COLL_NAMES = { 'arrays', 'WithVectorMember', 'info', 'fixedWidthInts', 'mcparticles', @@ -34,6 +36,63 @@ def test_frame_invalid_access(self): with self.assertRaises(KeyError): _ = frame.get_parameter('NonExistantParameter') + with self.assertRaises(ValueError): + collection = [1, 2, 4] + _ = frame.put(collection, "invalid_collection_type") + + def test_frame_put_collection(self): + """Check that putting a collection works as expected""" + frame = Frame() + self.assertEqual(frame.collections, tuple()) + + hits = ExampleHitCollection() + hits.create() + hits2 = frame.put(hits, "hits_from_python") + self.assertEqual(frame.collections, tuple(["hits_from_python"])) + # The original collection is gone at this point, and ideally just leaves an + # empty shell + self.assertEqual(len(hits), 0) + # On the other hand the return value of put has the original content + self.assertEqual(len(hits2), 1) + + def test_frame_put_parameters(self): + """Check that putting a parameter works as expected""" + frame = Frame() + self.assertEqual(frame.parameters, tuple()) + + frame.put_parameter("a_string_param", "a string") + self.assertEqual(frame.parameters, tuple(["a_string_param"])) + self.assertEqual(frame.get_parameter("a_string_param"), "a string") + + frame.put_parameter("float_param", 3.14) + self.assertEqual(frame.get_parameter("float_param"), 3.14) + + frame.put_parameter("int", 42) + self.assertEqual(frame.get_parameter("int"), 42) + + frame.put_parameter("string_vec", ["a", "b", "cd"]) + str_vec = frame.get_parameter("string_vec") + self.assertEqual(len(str_vec), 3) + self.assertEqual(str_vec, ["a", "b", "cd"]) + + frame.put_parameter("more_ints", [1, 2345]) + int_vec = frame.get_parameter("more_ints") + self.assertEqual(len(int_vec), 2) + self.assertEqual(int_vec, [1, 2345]) + + frame.put_parameter("float_vec", [1.23, 4.56, 7.89]) + vec = frame.get_parameter("float_vec", as_type="double") + self.assertEqual(len(vec), 3) + self.assertEqual(vec, [1.23, 4.56, 7.89]) + + frame.put_parameter("real_float_vec", [1.23, 4.56, 7.89], as_type="float") + f_vec = frame.get_parameter("real_float_vec", as_type="float") + self.assertEqual(len(f_vec), 3) + self.assertEqual(vec, [1.23, 4.56, 7.89]) + + frame.put_parameter("float_as_float", 3.14, as_type="float") + self.assertAlmostEqual(frame.get_parameter("float_as_float"), 3.14, places=5) + class FrameReadTest(unittest.TestCase): """Unit tests for the Frame python bindings for Frames read from file. diff --git a/python/podio/test_utils.py b/python/podio/test_utils.py index 2c5e282b6..44efc9cce 100644 --- a/python/podio/test_utils.py +++ b/python/podio/test_utils.py @@ -2,5 +2,57 @@ """Utilities for python unittests""" import os +import ROOT +ROOT.gSystem.Load("libTestDataModelDict.so") # noqa: E402 +from ROOT import ExampleHitCollection, ExampleClusterCollection # noqa: E402 # pylint: disable=wrong-import-position -SKIP_SIO_TESTS = os.environ.get('SKIP_SIO_TESTS', '1') == '1' +from podio.frame import Frame # pylint: disable=wrong-import-position + + +SKIP_SIO_TESTS = os.environ.get("SKIP_SIO_TESTS", "1") == "1" + + +def create_hit_collection(): + """Create a simple hit collection with two hits for testing""" + hits = ExampleHitCollection() + hits.create(0xBAD, 0.0, 0.0, 0.0, 23.0) + hits.create(0xCAFFEE, 1.0, 0.0, 0.0, 12.0) + + return hits + + +def create_cluster_collection(): + """Create a simple cluster collection with two clusters""" + clusters = ExampleClusterCollection() + clu0 = clusters.create() + clu0.energy(3.14) + clu1 = clusters.create() + clu1.energy(1.23) + + return clusters + + +def create_frame(): + """Create a frame with an ExampleHit and an ExampleCluster collection""" + frame = Frame() + hits = create_hit_collection() + frame.put(hits, "hits_from_python") + clusters = create_cluster_collection() + frame.put(clusters, "clusters_from_python") + + frame.put_parameter("an_int", 42) + frame.put_parameter("some_floats", [1.23, 7.89, 3.14]) + frame.put_parameter("greetings", ["from", "python"]) + frame.put_parameter("real_float", 3.14, as_type="float") + frame.put_parameter("more_real_floats", [1.23, 4.56, 7.89], as_type="float") + + return frame + + +def write_file(writer_type, filename): + """Write a file using the given Writer type and put one Frame into it under + the events category + """ + writer = writer_type(filename) + event = create_frame() + writer.write_frame(event, "events") diff --git a/tests/CTestCustom.cmake b/tests/CTestCustom.cmake index b0e683f65..d4d05cd2a 100644 --- a/tests/CTestCustom.cmake +++ b/tests/CTestCustom.cmake @@ -22,6 +22,8 @@ if ((NOT "@FORCE_RUN_ALL_TESTS@" STREQUAL "ON") AND (NOT "@USE_SANITIZER@" STREQ read-legacy-files-root_v00-13 read_frame_legacy_root read_frame_root_multiple + write_python_frame_root + read_python_frame_root write_frame_root read_frame_root @@ -35,6 +37,8 @@ if ((NOT "@FORCE_RUN_ALL_TESTS@" STREQUAL "ON") AND (NOT "@USE_SANITIZER@" STREQ write_frame_sio read_frame_sio read_frame_legacy_sio + write_python_frame_sio + read_python_frame_sio write_ascii diff --git a/tests/read_python_frame.h b/tests/read_python_frame.h new file mode 100644 index 000000000..5a06cc4ce --- /dev/null +++ b/tests/read_python_frame.h @@ -0,0 +1,106 @@ +#ifndef PODIO_TESTS_READ_PYTHON_FRAME_H // NOLINT(llvm-header-guard): folder structure not suitable +#define PODIO_TESTS_READ_PYTHON_FRAME_H // NOLINT(llvm-header-guard): folder structure not suitable + +#include "datamodel/ExampleClusterCollection.h" +#include "datamodel/ExampleHitCollection.h" + +#include "podio/Frame.h" + +#include + +int checkHits(const ExampleHitCollection& hits) { + if (hits.size() != 2) { + std::cerr << "There should be two hits in the collection (actual size: " << hits.size() << ")" << std::endl; + return 1; + } + + auto hit1 = hits[0]; + if (hit1.cellID() != 0xbad || hit1.x() != 0.0 || hit1.y() != 0.0 || hit1.z() != 0.0 || hit1.energy() != 23.0) { + std::cerr << "Could not retrieve the correct hit[0]: (expected: " << ExampleHit(0xbad, 0.0, 0.0, 0.0, 23.0) + << ", actual: " << hit1 << ")" << std::endl; + return 1; + } + + auto hit2 = hits[1]; + if (hit2.cellID() != 0xcaffee || hit2.x() != 1.0 || hit2.y() != 0.0 || hit2.z() != 0.0 || hit2.energy() != 12.0) { + std::cerr << "Could not retrieve the correct hit[1]: (expected: " << ExampleHit(0xcaffee, 1.0, 0.0, 0.0, 12.0) + << ", actual: " << hit1 << ")" << std::endl; + return 1; + } + + return 0; +} + +int checkClusters(const ExampleClusterCollection& clusters) { + if (clusters.size() != 2) { + std::cerr << "There should be two clusters in the collection (actual size: " << clusters.size() << ")" << std::endl; + return 1; + } + + if (clusters[0].energy() != 3.14 || clusters[1].energy() != 1.23) { + std::cerr << "Energies of the clusters is wrong: (expected: 3.14 and 1.23, actual " << clusters[0].energy() + << " and " << clusters[1].energy() << ")" << std::endl; + return 1; + } + + return 0; +} + +template +std::ostream& operator<<(std::ostream& o, const std::vector& vec) { + auto delim = "["; + for (const auto& v : vec) { + o << std::exchange(delim, ", ") << v; + } + return o << "]"; +} + +int checkParameters(const podio::Frame& frame) { + const auto iVal = frame.getParameter("an_int"); + if (iVal != 42) { + std::cerr << "Parameter an_int was not stored correctly (expected 42, actual " << iVal << ")" << std::endl; + return 1; + } + + const auto& dVal = frame.getParameter>("some_floats"); + if (dVal.size() != 3 || dVal[0] != 1.23 || dVal[1] != 7.89 || dVal[2] != 3.14) { + std::cerr << "Parameter some_floats was not stored correctly (expected [1.23, 7.89, 3.14], actual " << dVal << ")" + << std::endl; + return 1; + } + + const auto& strVal = frame.getParameter>("greetings"); + if (strVal.size() != 2 || strVal[0] != "from" || strVal[1] != "python") { + std::cerr << "Parameter greetings was not stored correctly (expected [from, python], actual " << strVal << ")" + << std::endl; + return 1; + } + + const auto realFloat = frame.getParameter("real_float"); + if (realFloat != 3.14f) { + std::cerr << "Parameter real_float was not stored correctly (expected 3.14, actual " << realFloat << ")" + << std::endl; + return 1; + } + + const auto& realFloats = frame.getParameter>("more_real_floats"); + if (realFloats.size() != 3 || realFloats[0] != 1.23f || realFloats[1] != 4.56f || realFloats[2] != 7.89f) { + std::cerr << "Parameter more_real_floats was not stored as correctly (expected [1.23, 4.56, 7.89], actual" + << realFloats << ")" << std::endl; + } + + return 0; +} + +template +int read_frame(const std::string& filename) { + auto reader = ReaderT(); + reader.openFile(filename); + + auto event = podio::Frame(reader.readEntry("events", 0)); + + return checkHits(event.get("hits_from_python")) + + checkClusters(event.get("clusters_from_python")) + checkParameters(event); +} + +#endif // PODIO_TESTS_READ_PYTHON_FRAME_H diff --git a/tests/root_io/CMakeLists.txt b/tests/root_io/CMakeLists.txt index bfa8309f3..5c867a9fe 100644 --- a/tests/root_io/CMakeLists.txt +++ b/tests/root_io/CMakeLists.txt @@ -11,6 +11,7 @@ set(root_dependent_tests write_frame_root.cpp read_frame_legacy_root.cpp read_frame_root_multiple.cpp + read_python_frame_root.cpp ) if(ENABLE_RNTUPLE) set(root_dependent_tests @@ -69,3 +70,9 @@ if (DEFINED CACHE{PODIO_TEST_INPUT_DATA_DIR}) ADD_PODIO_LEGACY_TEST(${version} read_frame_legacy_root example.root legacy_test_cases) endforeach() endif() + +#--- Write via python and the ROOT backend and see if we can read it back in in +#--- c++ +add_test(NAME write_python_frame_root COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/write_frame_root.py) +PODIO_SET_TEST_ENV(write_python_frame_root) +set_property(TEST read_python_frame_root PROPERTY DEPENDS write_python_frame_root) diff --git a/tests/root_io/read_python_frame_root.cpp b/tests/root_io/read_python_frame_root.cpp new file mode 100644 index 000000000..23d1c0015 --- /dev/null +++ b/tests/root_io/read_python_frame_root.cpp @@ -0,0 +1,7 @@ +#include "read_python_frame.h" + +#include "podio/ROOTFrameReader.h" + +int main() { + return read_frame("example_frame_with_py.root"); +} diff --git a/tests/root_io/write_frame_root.py b/tests/root_io/write_frame_root.py new file mode 100644 index 000000000..38bece171 --- /dev/null +++ b/tests/root_io/write_frame_root.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +"""Script to write a Frame in ROOT format""" + +from podio import test_utils +from podio.root_io import Writer + +test_utils.write_file(Writer, "example_frame_with_py.root") diff --git a/tests/sio_io/read_python_frame_sio.cpp b/tests/sio_io/read_python_frame_sio.cpp new file mode 100644 index 000000000..61c3eb481 --- /dev/null +++ b/tests/sio_io/read_python_frame_sio.cpp @@ -0,0 +1,7 @@ +#include "read_python_frame.h" + +#include "podio/SIOFrameReader.h" + +int main() { + return read_frame("example_frame_with_py.sio"); +} diff --git a/tests/sio_io/write_frame_sio.py b/tests/sio_io/write_frame_sio.py new file mode 100644 index 000000000..94e08aa27 --- /dev/null +++ b/tests/sio_io/write_frame_sio.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +"""Script to write a Frame in SIO format""" + +from podio import test_utils +from podio.sio_io import Writer + +test_utils.write_file(Writer, "example_frame_with_py.sio")