Skip to content

Commit

Permalink
refactor: Improved Python bindings for algebra types (acts-project#3611)
Browse files Browse the repository at this point in the history
This aligns the algebra API in python more closely with the C++ one. While I don't want to add too much binding code, I think it's preferable to stick to the Eigen API rather than make up extra custom API just for the sake of exposing it to Python.

To this end, I'm exposing more constructors for the `Vector` classes, and adding `Translation3` and `AngleAxis3` as Python types, and allow operators for multiplication and so on between transforms.

At the same time, I'm using this PR to also align some enums we expose to the C++ ones.
  • Loading branch information
paulgessinger authored Sep 19, 2024
1 parent 1451304 commit c334ce4
Showing 1 changed file with 77 additions and 32 deletions.
109 changes: 77 additions & 32 deletions Examples/Python/src/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,12 @@ void addAlgebra(Acts::Python::Context& ctx) {
return v;
}))
.def("__getitem__",
[](const Acts::Vector2& self, Eigen::Index i) { return self[i]; });
[](const Acts::Vector2& self, Eigen::Index i) { return self[i]; })
.def("__str__", [](const Acts::Vector3& self) {
std::stringstream ss;
ss << self.transpose();
return ss.str();
});

py::class_<Acts::Vector3>(m, "Vector3")
.def(py::init<double, double, double>())
Expand All @@ -277,8 +282,17 @@ void addAlgebra(Acts::Python::Context& ctx) {
v << a[0], a[1], a[2];
return v;
}))
.def_static("UnitX", []() -> Vector3 { return Acts::Vector3::UnitX(); })
.def_static("UnitY", []() -> Vector3 { return Acts::Vector3::UnitY(); })
.def_static("UnitZ", []() -> Vector3 { return Acts::Vector3::UnitZ(); })

.def("__getitem__",
[](const Acts::Vector3& self, Eigen::Index i) { return self[i]; });
[](const Acts::Vector3& self, Eigen::Index i) { return self[i]; })
.def("__str__", [](const Acts::Vector3& self) {
std::stringstream ss;
ss << self.transpose();
return ss.str();
});

py::class_<Acts::Vector4>(m, "Vector4")
.def(py::init<double, double, double, double>())
Expand All @@ -291,43 +305,74 @@ void addAlgebra(Acts::Python::Context& ctx) {
[](const Acts::Vector4& self, Eigen::Index i) { return self[i]; });

py::class_<Acts::Transform3>(m, "Transform3")
.def(py::init([](std::array<double, 3> translation) {
Acts::Transform3 t = Acts::Transform3::Identity();
t.pretranslate(
Acts::Vector3(translation[0], translation[1], translation[2]));
return t;
.def(py::init<>())
.def(py::init([](const Vector3& translation) -> Transform3 {
return Transform3{Translation3{translation}};
}))
.def("getTranslation",
[](const Acts::Transform3& self) {
return Vector3(self.translation());
.def_property_readonly("translation",
[](const Acts::Transform3& self) -> Vector3 {
return self.translation();
})
.def_static("Identity", &Acts::Transform3::Identity)
.def("__mul__",
[](const Acts::Transform3& self, const Acts::Transform3& other) {
return self * other;
})
.def("__mul__",
[](const Acts::Transform3& self, const Acts::Translation3& other) {
return self * other;
})
.def_static("Identity", &Acts::Transform3::Identity);
;
.def("__mul__",
[](const Acts::Transform3& self, const Acts::AngleAxis3& other) {
return self * other;
})
.def("__str__", [](const Acts::Transform3& self) {
std::stringstream ss;
ss << self.matrix();
return ss.str();
});

py::class_<Acts::Translation3>(m, "Translation3")
.def(py::init(
[](const Acts::Vector3& a) { return Acts::Translation3(a); }))
.def(py::init([](std::array<double, 3> a) {
return Acts::Translation3(Acts::Vector3(a[0], a[1], a[2]));
}))
.def("__str__", [](const Acts::Translation3& self) {
std::stringstream ss;
ss << self.translation().transpose();
return ss.str();
});

py::class_<Acts::AngleAxis3>(m, "AngleAxis3")
.def(py::init([](double angle, const Acts::Vector3& axis) {
return Acts::AngleAxis3(angle, axis);
}))
.def("__str__", [](const Acts::Transform3& self) {
std::stringstream ss;
ss << self.matrix();
return ss.str();
});
}

void addBinning(Context& ctx) {
auto& m = ctx.get("main");
auto binning = m.def_submodule("Binning", "");

auto binningValue = py::enum_<Acts::BinningValue>(binning, "BinningValue")
.value("x", Acts::BinningValue::binX)
.value("y", Acts::BinningValue::binY)
.value("z", Acts::BinningValue::binZ)
.value("r", Acts::BinningValue::binR)
.value("phi", Acts::BinningValue::binPhi)
.export_values();

auto boundaryType =
py::enum_<Acts::AxisBoundaryType>(binning, "AxisBoundaryType")
.value("bound", Acts::AxisBoundaryType::Bound)
.value("closed", Acts::AxisBoundaryType::Closed)
.value("open", Acts::AxisBoundaryType::Open)
.export_values();

auto axisType = py::enum_<Acts::AxisType>(binning, "AxisType")

auto binningValue = py::enum_<Acts::BinningValue>(m, "BinningValue")
.value("binX", Acts::BinningValue::binX)
.value("binY", Acts::BinningValue::binY)
.value("binZ", Acts::BinningValue::binZ)
.value("binR", Acts::BinningValue::binR)
.value("binPhi", Acts::BinningValue::binPhi);

auto boundaryType = py::enum_<Acts::AxisBoundaryType>(m, "AxisBoundaryType")
.value("bound", Acts::AxisBoundaryType::Bound)
.value("closed", Acts::AxisBoundaryType::Closed)
.value("open", Acts::AxisBoundaryType::Open);

auto axisType = py::enum_<Acts::AxisType>(m, "AxisType")
.value("equidistant", Acts::AxisType::Equidistant)
.value("variable", Acts::AxisType::Variable)
.export_values();
.value("variable", Acts::AxisType::Variable);
}

} // namespace Acts::Python

0 comments on commit c334ce4

Please sign in to comment.