diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index bafe0fdedf8..925c4546c1f 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -263,6 +263,13 @@ migraphx::shape to_shape(const py::buffer_info& info) { migraphx::shape::type_t t; std::size_t n = 0; + // Unsupported pybuffer types lead to undefined behaviour when comparing with migraphx type enum + if(info.format == "z") + { + MIGRAPHX_THROW( + "MIGRAPHX PYTHON: Unsupported data type. For fp8 and bf16 literals try using " + "migraphx.generate_argument with migraphx.add_literal"); + } visit_types([&](auto as) { if(info.format == py::format_descriptor::format() or (info.format == "l" and py::format_descriptor::format() == "q") or @@ -388,6 +395,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) py::arg("op"), py::arg("args"), py::arg("mod_args") = std::vector{}) + .def( + "add_literal", + [](migraphx::module& mm, migraphx::argument a) { + return mm.add_literal(a.get_shape(), a.data()); + }, + py::arg("data")) .def( "add_literal", [](migraphx::module& mm, py::buffer data) {