Skip to content

Commit

Permalink
[mlir][python] extend LLVM bindings (#89797)
Browse files Browse the repository at this point in the history
Add bindings for LLVM pointer type.
  • Loading branch information
makslevental authored Apr 24, 2024
1 parent 6e9ea6e commit 79d4d16
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 8 deletions.
7 changes: 7 additions & 0 deletions mlir/include/mlir-c/Dialect/LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm);
MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx,
unsigned addressSpace);

/// Returns `true` if the type is an LLVM dialect pointer type.
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type);

/// Returns address space of llvm.ptr
MLIR_CAPI_EXPORTED unsigned
mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType);

/// Creates an llmv.void type.
MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx);

Expand Down
43 changes: 35 additions & 8 deletions mlir/lib/Bindings/Python/DialectLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ using namespace mlir::python;
using namespace mlir::python::adaptors;

void populateDialectLLVMSubmodule(const pybind11::module &m) {

//===--------------------------------------------------------------------===//
// StructType
//===--------------------------------------------------------------------===//

auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);

Expand All @@ -35,25 +40,24 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
}
return cls(type);
},
py::arg("cls"), py::arg("elements"), py::kw_only(),
py::arg("packed") = false, py::arg("loc") = py::none());
"cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
"loc"_a = py::none());

llvmStructType.def_classmethod(
"get_identified",
[](py::object cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeIdentifiedGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
py::arg("cls"), py::arg("name"), py::kw_only(),
py::arg("context") = py::none());
"cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());

llvmStructType.def_classmethod(
"get_opaque",
[](py::object cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeOpaqueGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
"cls"_a, "name"_a, "context"_a = py::none());

llvmStructType.def(
"set_body",
Expand All @@ -65,7 +69,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
"Struct body already set to different content.");
}
},
py::arg("elements"), py::kw_only(), py::arg("packed") = false);
"elements"_a, py::kw_only(), "packed"_a = false);

llvmStructType.def_classmethod(
"new_identified",
Expand All @@ -75,8 +79,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
ctx, mlirStringRefCreate(name.data(), name.length()),
elements.size(), elements.data(), packed));
},
py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
py::arg("packed") = false, py::arg("context") = py::none());
"cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
"context"_a = py::none());

llvmStructType.def_property_readonly(
"name", [](MlirType type) -> std::optional<std::string> {
Expand Down Expand Up @@ -105,6 +109,29 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {

llvmStructType.def_property_readonly(
"opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });

//===--------------------------------------------------------------------===//
// PointerType
//===--------------------------------------------------------------------===//

mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
.def_classmethod(
"get",
[](py::object cls, std::optional<unsigned> addressSpace,
MlirContext context) {
CollectDiagnosticsToStringScope scope(context);
MlirType type = mlirLLVMPointerTypeGet(
context, addressSpace.has_value() ? *addressSpace : 0);
if (mlirTypeIsNull(type)) {
throw py::value_error(scope.takeMessage());
}
return cls(type);
},
"cls"_a, "address_space"_a = py::none(), py::kw_only(),
"context"_a = py::none())
.def_property_readonly("address_space", [](MlirType type) {
return mlirLLVMPointerTypeGetAddressSpace(type);
});
}

PYBIND11_MODULE(_mlirDialectsLLVM, m) {
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/CAPI/Dialect/LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) {
return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace));
}

bool mlirTypeIsALLVMPointerType(MlirType type) {
return isa<LLVM::LLVMPointerType>(unwrap(type));
}

unsigned mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType) {
return cast<LLVM::LLVMPointerType>(unwrap(pointerType)).getAddressSpace();
}

MlirType mlirLLVMVoidTypeGet(MlirContext ctx) {
return wrap(LLVMVoidType::get(unwrap(ctx)));
}
Expand Down
1 change: 1 addition & 0 deletions mlir/python/mlir/dialects/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
#define PYTHON_BINDINGS_LLVM_OPS

include "mlir/Dialect/LLVMIR/LLVMOps.td"
include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td"

#endif
8 changes: 8 additions & 0 deletions mlir/python/mlir/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,11 @@
from ._llvm_ops_gen import *
from ._llvm_enum_gen import *
from .._mlir_libs._mlirDialectsLLVM import *
from ..ir import Value
from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results


def mlir_constant(value, *, loc=None, ip=None) -> Value:
return _get_op_result_or_op_results(
ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
)
43 changes: 43 additions & 0 deletions mlir/test/python/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,46 @@ def testSmoke():
)
result = llvm.UndefOp(mat64f32_t)
# CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>


# CHECK-LABEL: testPointerType
@constructAndPrintInModule
def testPointerType():
ptr = llvm.PointerType.get()
# CHECK: !llvm.ptr
print(ptr)

ptr_with_addr = llvm.PointerType.get(1)
# CHECK: !llvm.ptr<1>
print(ptr_with_addr)


# CHECK-LABEL: testConstant
@constructAndPrintInModule
def testConstant():
i32 = IntegerType.get_signless(32)
c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
# CHECK: %{{.*}} = llvm.mlir.constant(128 : i32) : i32
print(c_128.owner)


# CHECK-LABEL: testIntrinsics
@constructAndPrintInModule
def testIntrinsics():
i32 = IntegerType.get_signless(32)
ptr = llvm.PointerType.get()
c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
# CHECK: %[[CST128:.*]] = llvm.mlir.constant(128 : i32) : i32
print(c_128.owner)

alloca = llvm.alloca(ptr, c_128, i32)
# CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[CST128]] x i32 : (i32) -> !llvm.ptr
print(alloca.owner)

c_0 = llvm.mlir_constant(IntegerAttr.get(IntegerType.get_signless(8), 0))
# CHECK: %[[CST0:.+]] = llvm.mlir.constant(0 : i8) : i8
print(c_0.owner)

result = llvm.intr_memset(alloca, c_0, c_128, False)
# CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[CST0]], %[[CST128]]) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
print(result)

0 comments on commit 79d4d16

Please sign in to comment.