Skip to content

Commit

Permalink
[CIR][NFC] move data member pointer lowering to CXXABI
Browse files Browse the repository at this point in the history
This patch moves the lowering code for data member pointers from the conversion
patterns to the implementation of CXXABI because this part should be ABI-
specific.
  • Loading branch information
Lancern committed Nov 23, 2024
1 parent 4aca8d4 commit 94f8d28
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 55 deletions.
26 changes: 26 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@
#define LLVM_CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRCXXABI_H

#include "LowerFunctionInfo.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Target/AArch64.h"

namespace cir {
Expand Down Expand Up @@ -59,6 +65,26 @@ class CIRCXXABI {
/// Returns how an argument of the given record type should be passed.
/// FIXME(cir): This expects a CXXRecordDecl! Not any record type.
virtual RecordArgABI getRecordArgABI(const StructType RD) const = 0;

/// Lower the given data member pointer type to its ABI type. The returned
/// type is also a CIR type.
virtual mlir::Type
lowerDataMemberType(cir::DataMemberType type,
const mlir::TypeConverter &typeConverter) const = 0;

/// Lower the given data member pointer constant to a constant of the ABI
/// type. The returned constant is represented as an attribute as well.
virtual mlir::TypedAttr
lowerDataMemberConstant(cir::DataMemberAttr attr,
const mlir::DataLayout &layout,
const mlir::TypeConverter &typeConverter) const = 0;

/// Lower the given cir.get_runtime_member op to a sequence of more
/// "primitive" CIR operations that act on the ABI types.
virtual mlir::Operation *
lowerGetRuntimeMember(cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy,
mlir::Value loweredAddr, mlir::Value loweredMember,
mlir::OpBuilder &builder) const = 0;
};

/// Creates an Itanium-family ABI.
Expand Down
62 changes: 62 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "../LoweringPrepareCXXABI.h"
#include "CIRCXXABI.h"
#include "LowerModule.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/ErrorHandling.h"

namespace cir {
Expand Down Expand Up @@ -51,6 +52,19 @@ class ItaniumCXXABI : public CIRCXXABI {
cir_cconv_assert(!cir::MissingFeatures::recordDeclCanPassInRegisters());
return RAA_Default;
}

mlir::Type
lowerDataMemberType(cir::DataMemberType type,
const mlir::TypeConverter &typeConverter) const override;

mlir::TypedAttr lowerDataMemberConstant(
cir::DataMemberAttr attr, const mlir::DataLayout &layout,
const mlir::TypeConverter &typeConverter) const override;

mlir::Operation *
lowerGetRuntimeMember(cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy,
mlir::Value loweredAddr, mlir::Value loweredMember,
mlir::OpBuilder &builder) const override;
};

} // namespace
Expand All @@ -67,6 +81,54 @@ bool ItaniumCXXABI::classifyReturnType(LowerFunctionInfo &FI) const {
return false;
}

mlir::Type ItaniumCXXABI::lowerDataMemberType(
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
// Itanium C++ ABI 2.3:
// A pointer to data member is an offset from the base address of
// the class object containing it, represented as a ptrdiff_t
const clang::TargetInfo &target = LM.getTarget();
clang::TargetInfo::IntType ptrdiffTy =
target.getPtrDiffType(clang::LangAS::Default);
return cir::IntType::get(type.getContext(), target.getTypeWidth(ptrdiffTy),
target.isTypeSigned(ptrdiffTy));
}

mlir::TypedAttr ItaniumCXXABI::lowerDataMemberConstant(
cir::DataMemberAttr attr, const mlir::DataLayout &layout,
const mlir::TypeConverter &typeConverter) const {
uint64_t memberOffset;
if (attr.isNullPtr()) {
// Itanium C++ ABI 2.3:
// A NULL pointer is represented as -1.
memberOffset = -1ull;
} else {
// Itanium C++ ABI 2.3:
// A pointer to data member is an offset from the base address of
// the class object containing it, represented as a ptrdiff_t
auto memberIndex = attr.getMemberIndex().value();
memberOffset =
attr.getType().getClsTy().getElementOffset(layout, memberIndex);
}

mlir::Type abiTy = lowerDataMemberType(attr.getType(), typeConverter);
return cir::IntAttr::get(abiTy, memberOffset);
}

mlir::Operation *ItaniumCXXABI::lowerGetRuntimeMember(
cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy,
mlir::Value loweredAddr, mlir::Value loweredMember,
mlir::OpBuilder &builder) const {
auto byteTy = IntType::get(op.getContext(), 8, true);
auto bytePtrTy = PointerType::get(
byteTy, mlir::cast<PointerType>(op.getAddr().getType()).getAddrSpace());
auto objectBytesPtr = builder.create<CastOp>(op.getLoc(), bytePtrTy,
CastKind::bitcast, op.getAddr());
auto memberBytesPtr = builder.create<PtrStrideOp>(
op.getLoc(), bytePtrTy, objectBytesPtr, loweredMember);
return builder.create<CastOp>(op.getLoc(), op.getType(), CastKind::bitcast,
memberBytesPtr);
}

CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) {
switch (LM.getCXXABIKind()) {
// Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't
Expand Down
74 changes: 33 additions & 41 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1521,28 +1521,6 @@ bool hasTrailingZeros(cir::ConstArrayAttr attr) {
}));
}

static mlir::Attribute
lowerDataMemberAttr(mlir::ModuleOp moduleOp, cir::DataMemberAttr attr,
const mlir::TypeConverter &typeConverter) {
mlir::DataLayout layout{moduleOp};

uint64_t memberOffset;
if (attr.isNullPtr()) {
// TODO(cir): the numerical value of a null data member pointer is
// ABI-specific and should be queried through ABI.
assert(!MissingFeatures::targetCodeGenInfoGetNullPointer());
memberOffset = -1ull;
} else {
auto memberIndex = attr.getMemberIndex().value();
memberOffset =
attr.getType().getClsTy().getElementOffset(layout, memberIndex);
}

auto underlyingIntTy = mlir::IntegerType::get(
moduleOp->getContext(), layout.getTypeSizeInBits(attr.getType()));
return mlir::IntegerAttr::get(underlyingIntTy, memberOffset);
}

mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
cir::ConstantOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -1602,9 +1580,13 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
}
attr = op.getValue();
} else if (mlir::isa<cir::DataMemberType>(op.getType())) {
assert(lowerMod && "lower module is not available");
auto dataMember = mlir::cast<cir::DataMemberAttr>(op.getValue());
attr = lowerDataMemberAttr(op->getParentOfType<mlir::ModuleOp>(),
dataMember, *typeConverter);
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
mlir::TypedAttr abiValue = lowerMod->getCXXABI().lowerDataMemberConstant(
dataMember, layout, *typeConverter);
rewriter.replaceOpWithNewOp<ConstantOp>(op, abiValue);
return mlir::success();
}
// TODO(cir): constant arrays are currently just pushed into the stack using
// the store instruction, instead of being stored as global variables and
Expand Down Expand Up @@ -2208,8 +2190,15 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
return mlir::success();
} else if (auto dataMemberAttr =
mlir::dyn_cast<cir::DataMemberAttr>(init.value())) {
init = lowerDataMemberAttr(op->getParentOfType<mlir::ModuleOp>(),
dataMemberAttr, *typeConverter);
assert(lowerMod && "lower module is not available");
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
mlir::TypedAttr abiValue = lowerMod->getCXXABI().lowerDataMemberConstant(
dataMemberAttr, layout, *typeConverter);
auto abiOp = mlir::cast<GlobalOp>(rewriter.clone(*op.getOperation()));
abiOp.setInitialValueAttr(abiValue);
abiOp.setSymType(abiValue.getType());
rewriter.replaceOp(op, abiOp);
return mlir::success();
} else if (const auto structAttr =
mlir::dyn_cast<cir::ConstStructAttr>(init.value())) {
setupRegionInitializedLLVMGlobalOp(op, rewriter);
Expand Down Expand Up @@ -3237,11 +3226,11 @@ mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite(
mlir::LogicalResult CIRToLLVMGetRuntimeMemberOpLowering::matchAndRewrite(
cir::GetRuntimeMemberOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto llvmResTy = getTypeConverter()->convertType(op.getType());
auto llvmElementTy = mlir::IntegerType::get(op.getContext(), 8);

rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
op, llvmResTy, llvmElementTy, adaptor.getAddr(), adaptor.getMember());
assert(lowerMod && "lowering module is not available");
mlir::Type llvmResTy = getTypeConverter()->convertType(op.getType());
mlir::Operation *llvmOp = lowerMod->getCXXABI().lowerGetRuntimeMember(
op, llvmResTy, adaptor.getAddr(), adaptor.getMember(), rewriter);
rewriter.replaceOp(op, llvmOp);
return mlir::success();
}

Expand Down Expand Up @@ -3850,14 +3839,17 @@ mlir::LogicalResult CIRToLLVMSignBitOpLowering::matchAndRewrite(

void populateCIRToLLVMConversionPatterns(
mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter,
mlir::DataLayout &dataLayout,
mlir::DataLayout &dataLayout, cir::LowerModule *lowerModule,
llvm::StringMap<mlir::LLVM::GlobalOp> &stringGlobalsMap,
llvm::StringMap<mlir::LLVM::GlobalOp> &argStringGlobalsMap,
llvm::MapVector<mlir::ArrayAttr, mlir::LLVM::GlobalOp> &argsVarMap) {
patterns.add<CIRToLLVMReturnOpLowering>(patterns.getContext());
patterns.add<CIRToLLVMAllocaOpLowering>(converter, dataLayout,
stringGlobalsMap, argStringGlobalsMap,
argsVarMap, patterns.getContext());
patterns.add<CIRToLLVMConstantOpLowering, CIRToLLVMGlobalOpLowering,
CIRToLLVMGetRuntimeMemberOpLowering>(
converter, patterns.getContext(), lowerModule);
patterns.add<
// clang-format off
CIRToLLVMAbsOpLowering,
Expand Down Expand Up @@ -3891,7 +3883,6 @@ void populateCIRToLLVMConversionPatterns(
CIRToLLVMComplexImagPtrOpLowering,
CIRToLLVMComplexRealOpLowering,
CIRToLLVMComplexRealPtrOpLowering,
CIRToLLVMConstantOpLowering,
CIRToLLVMCopyOpLowering,
CIRToLLVMDerivedClassAddrOpLowering,
CIRToLLVMEhInflightOpLowering,
Expand All @@ -3902,8 +3893,6 @@ void populateCIRToLLVMConversionPatterns(
CIRToLLVMGetBitfieldOpLowering,
CIRToLLVMGetGlobalOpLowering,
CIRToLLVMGetMemberOpLowering,
CIRToLLVMGetRuntimeMemberOpLowering,
CIRToLLVMGlobalOpLowering,
CIRToLLVMInlineAsmOpLowering,
CIRToLLVMIsConstantOpLowering,
CIRToLLVMIsFPClassOpLowering,
Expand Down Expand Up @@ -3990,10 +3979,13 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter,

return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS);
});
converter.addConversion([&](cir::DataMemberType type) -> mlir::Type {
return mlir::IntegerType::get(type.getContext(),
dataLayout.getTypeSizeInBits(type));
});
converter.addConversion(
[&, lowerModule](cir::DataMemberType type) -> mlir::Type {
assert(lowerModule && "CXXABI is not available");
mlir::Type abiType =
lowerModule->getCXXABI().lowerDataMemberType(type, converter);
return converter.convertType(abiType);
});
converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
auto ty = converter.convertType(type.getEltType());
return mlir::LLVM::LLVMArrayType::get(ty, type.getSize());
Expand Down Expand Up @@ -4328,8 +4320,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
llvm::MapVector<mlir::ArrayAttr, mlir::LLVM::GlobalOp> argsVarMap;

populateCIRToLLVMConversionPatterns(patterns, converter, dataLayout,
stringGlobalsMap, argStringGlobalsMap,
argsVarMap);
lowerModule.get(), stringGlobalsMap,
argStringGlobalsMap, argsVarMap);
mlir::populateFuncToLLVMConversionPatterns(converter, patterns);

mlir::ConversionTarget target(getContext());
Expand Down
25 changes: 22 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,15 @@ class CIRToLLVMStoreOpLowering

class CIRToLLVMConstantOpLowering
: public mlir::OpConversionPattern<cir::ConstantOp> {
cir::LowerModule *lowerMod;

public:
using mlir::OpConversionPattern<cir::ConstantOp>::OpConversionPattern;
CIRToLLVMConstantOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
setHasBoundedRewriteRecursion();
}

mlir::LogicalResult
matchAndRewrite(cir::ConstantOp op, OpAdaptor,
Expand Down Expand Up @@ -490,8 +497,15 @@ class CIRToLLVMSwitchFlatOpLowering

class CIRToLLVMGlobalOpLowering
: public mlir::OpConversionPattern<cir::GlobalOp> {
cir::LowerModule *lowerMod;

public:
using mlir::OpConversionPattern<cir::GlobalOp>::OpConversionPattern;
CIRToLLVMGlobalOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
setHasBoundedRewriteRecursion();
}

mlir::LogicalResult
matchAndRewrite(cir::GlobalOp op, OpAdaptor,
Expand Down Expand Up @@ -774,8 +788,13 @@ class CIRToLLVMGetMemberOpLowering

class CIRToLLVMGetRuntimeMemberOpLowering
: public mlir::OpConversionPattern<cir::GetRuntimeMemberOp> {
cir::LowerModule *lowerMod;

public:
using mlir::OpConversionPattern<cir::GetRuntimeMemberOp>::OpConversionPattern;
CIRToLLVMGetRuntimeMemberOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {}

mlir::LogicalResult
matchAndRewrite(cir::GetRuntimeMemberOp op, OpAdaptor,
Expand Down
27 changes: 16 additions & 11 deletions clang/test/CIR/Lowering/data-member.cir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
!s64i = !cir.int<s, 64>
!structT = !cir.struct<struct "Point" {!cir.int<s, 32>, !cir.int<s, 32>, !cir.int<s, 32>}>

module @test {
module @test attributes {
cir.triple = "x86_64-unknown-linux-gnu",
llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
} {
cir.global external @pt_member = #cir.data_member<1> : !cir.data_member<!s32i in !structT>
// MLIR: llvm.mlir.global external @pt_member(4 : i64) {addr_space = 0 : i32} : i64
// LLVM: @pt_member = global i64 4
Expand All @@ -15,8 +18,8 @@ module @test {
cir.return %0 : !cir.data_member<!s32i in !structT>
}
// MLIR: llvm.func @constant() -> i64
// MLIR-NEXT: %0 = llvm.mlir.constant(4 : i64) : i64
// MLIR-NEXT: llvm.return %0 : i64
// MLIR-NEXT: %[[#VAL:]] = llvm.mlir.constant(4 : i64) : i64
// MLIR-NEXT: llvm.return %[[#VAL]] : i64
// MLIR-NEXT: }

// LLVM: define i64 @constant()
Expand All @@ -28,8 +31,8 @@ module @test {
cir.return %0 : !cir.data_member<!s32i in !structT>
}
// MLIR: llvm.func @null_constant() -> i64
// MLIR-NEXT: %0 = llvm.mlir.constant(-1 : i64) : i64
// MLIR-NEXT: llvm.return %0 : i64
// MLIR-NEXT: %[[#VAL:]] = llvm.mlir.constant(-1 : i64) : i64
// MLIR-NEXT: llvm.return %[[#VAL]] : i64
// MLIR-NEXT: }

// LLVM: define i64 @null_constant() !dbg !7 {
Expand All @@ -40,13 +43,15 @@ module @test {
%0 = cir.get_runtime_member %arg0[%arg1 : !cir.data_member<!s32i in !structT>] : !cir.ptr<!structT> -> !cir.ptr<!s32i>
cir.return %0 : !cir.ptr<!s32i>
}
// MLIR: llvm.func @get_runtime_member(%arg0: !llvm.ptr, %arg1: i64) -> !llvm.ptr
// MLIR-NEXT: %0 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// MLIR-NEXT: llvm.return %0 : !llvm.ptr
// MLIR: llvm.func @get_runtime_member(%[[ARG0:.+]]: !llvm.ptr, %[[ARG1:.+]]: i64) -> !llvm.ptr
// MLIR-NEXT: %[[#PTR:]] = llvm.bitcast %[[ARG0]] : !llvm.ptr to !llvm.ptr
// MLIR-NEXT: %[[#VAL:]] = llvm.getelementptr %[[#PTR]][%[[ARG1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// MLIR-NEXT: %[[#RET:]] = llvm.bitcast %[[#VAL]] : !llvm.ptr to !llvm.ptr
// MLIR-NEXT: llvm.return %[[#RET]] : !llvm.ptr
// MLIR-NEXT: }

// LLVM: define ptr @get_runtime_member(ptr %0, i64 %1)
// LLVM-NEXT: %3 = getelementptr i8, ptr %0, i64 %1
// LLVM-NEXT: ret ptr %3
// LLVM: define ptr @get_runtime_member(ptr %[[ARG0:.+]], i64 %[[ARG1:.+]])
// LLVM-NEXT: %[[#VAL:]] = getelementptr i8, ptr %[[ARG0]], i64 %[[ARG1]]
// LLVM-NEXT: ret ptr %[[#VAL]]
// LLVM-NEXT: }
}

0 comments on commit 94f8d28

Please sign in to comment.