From c334ce474bdf2f96ccd8f5e66b959b6bdba7b85e Mon Sep 17 00:00:00 2001 From: Paul Gessinger Date: Thu, 19 Sep 2024 13:03:36 +0200 Subject: [PATCH] refactor: Improved Python bindings for algebra types (#3611) This aligns the algebra API in python more closely with the C++ one. While I don't want to add too much binding code, I think it's preferable to stick to the Eigen API rather than make up extra custom API just for the sake of exposing it to Python. To this end, I'm exposing more constructors for the `Vector` classes, and adding `Translation3` and `AngleAxis3` as Python types, and allow operators for multiplication and so on between transforms. At the same time, I'm using this PR to also align some enums we expose to the C++ ones. --- Examples/Python/src/Base.cpp | 109 +++++++++++++++++++++++++---------- 1 file changed, 77 insertions(+), 32 deletions(-) diff --git a/Examples/Python/src/Base.cpp b/Examples/Python/src/Base.cpp index aae48a6a3b4..e92e4d17895 100644 --- a/Examples/Python/src/Base.cpp +++ b/Examples/Python/src/Base.cpp @@ -268,7 +268,12 @@ void addAlgebra(Acts::Python::Context& ctx) { return v; })) .def("__getitem__", - [](const Acts::Vector2& self, Eigen::Index i) { return self[i]; }); + [](const Acts::Vector2& self, Eigen::Index i) { return self[i]; }) + .def("__str__", [](const Acts::Vector3& self) { + std::stringstream ss; + ss << self.transpose(); + return ss.str(); + }); py::class_(m, "Vector3") .def(py::init()) @@ -277,8 +282,17 @@ void addAlgebra(Acts::Python::Context& ctx) { v << a[0], a[1], a[2]; return v; })) + .def_static("UnitX", []() -> Vector3 { return Acts::Vector3::UnitX(); }) + .def_static("UnitY", []() -> Vector3 { return Acts::Vector3::UnitY(); }) + .def_static("UnitZ", []() -> Vector3 { return Acts::Vector3::UnitZ(); }) + .def("__getitem__", - [](const Acts::Vector3& self, Eigen::Index i) { return self[i]; }); + [](const Acts::Vector3& self, Eigen::Index i) { return self[i]; }) + .def("__str__", [](const Acts::Vector3& self) { + std::stringstream ss; + ss << self.transpose(); + return ss.str(); + }); py::class_(m, "Vector4") .def(py::init()) @@ -291,43 +305,74 @@ void addAlgebra(Acts::Python::Context& ctx) { [](const Acts::Vector4& self, Eigen::Index i) { return self[i]; }); py::class_(m, "Transform3") - .def(py::init([](std::array translation) { - Acts::Transform3 t = Acts::Transform3::Identity(); - t.pretranslate( - Acts::Vector3(translation[0], translation[1], translation[2])); - return t; + .def(py::init<>()) + .def(py::init([](const Vector3& translation) -> Transform3 { + return Transform3{Translation3{translation}}; })) - .def("getTranslation", - [](const Acts::Transform3& self) { - return Vector3(self.translation()); + .def_property_readonly("translation", + [](const Acts::Transform3& self) -> Vector3 { + return self.translation(); + }) + .def_static("Identity", &Acts::Transform3::Identity) + .def("__mul__", + [](const Acts::Transform3& self, const Acts::Transform3& other) { + return self * other; + }) + .def("__mul__", + [](const Acts::Transform3& self, const Acts::Translation3& other) { + return self * other; }) - .def_static("Identity", &Acts::Transform3::Identity); - ; + .def("__mul__", + [](const Acts::Transform3& self, const Acts::AngleAxis3& other) { + return self * other; + }) + .def("__str__", [](const Acts::Transform3& self) { + std::stringstream ss; + ss << self.matrix(); + return ss.str(); + }); + + py::class_(m, "Translation3") + .def(py::init( + [](const Acts::Vector3& a) { return Acts::Translation3(a); })) + .def(py::init([](std::array a) { + return Acts::Translation3(Acts::Vector3(a[0], a[1], a[2])); + })) + .def("__str__", [](const Acts::Translation3& self) { + std::stringstream ss; + ss << self.translation().transpose(); + return ss.str(); + }); + + py::class_(m, "AngleAxis3") + .def(py::init([](double angle, const Acts::Vector3& axis) { + return Acts::AngleAxis3(angle, axis); + })) + .def("__str__", [](const Acts::Transform3& self) { + std::stringstream ss; + ss << self.matrix(); + return ss.str(); + }); } void addBinning(Context& ctx) { auto& m = ctx.get("main"); - auto binning = m.def_submodule("Binning", ""); - - auto binningValue = py::enum_(binning, "BinningValue") - .value("x", Acts::BinningValue::binX) - .value("y", Acts::BinningValue::binY) - .value("z", Acts::BinningValue::binZ) - .value("r", Acts::BinningValue::binR) - .value("phi", Acts::BinningValue::binPhi) - .export_values(); - - auto boundaryType = - py::enum_(binning, "AxisBoundaryType") - .value("bound", Acts::AxisBoundaryType::Bound) - .value("closed", Acts::AxisBoundaryType::Closed) - .value("open", Acts::AxisBoundaryType::Open) - .export_values(); - - auto axisType = py::enum_(binning, "AxisType") + + auto binningValue = py::enum_(m, "BinningValue") + .value("binX", Acts::BinningValue::binX) + .value("binY", Acts::BinningValue::binY) + .value("binZ", Acts::BinningValue::binZ) + .value("binR", Acts::BinningValue::binR) + .value("binPhi", Acts::BinningValue::binPhi); + + auto boundaryType = py::enum_(m, "AxisBoundaryType") + .value("bound", Acts::AxisBoundaryType::Bound) + .value("closed", Acts::AxisBoundaryType::Closed) + .value("open", Acts::AxisBoundaryType::Open); + + auto axisType = py::enum_(m, "AxisType") .value("equidistant", Acts::AxisType::Equidistant) - .value("variable", Acts::AxisType::Variable) - .export_values(); + .value("variable", Acts::AxisType::Variable); } } // namespace Acts::Python