Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] add ABI argument attributes in indirect calls #126896

Merged
merged 2 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,8 +593,36 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
call, resultTys, adaptor.getOperands(),
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
adaptor.getOperands().size()));
if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
llvmCall.setArgAttrsAttr(argAttrs);
if (mlir::ArrayAttr argAttrsArray = call.getArgAttrsAttr()) {
// sret and byval type needs to be converted.
auto convertTypeAttr = [&](const mlir::NamedAttribute &attr) {
return mlir::TypeAttr::get(convertType(
llvm::cast<mlir::TypeAttr>(attr.getValue()).getValue()));
};
llvm::SmallVector<mlir::Attribute> newArgAttrsArray;
for (auto argAttrs : argAttrsArray) {
llvm::SmallVector<mlir::NamedAttribute> convertedAttrs;
for (const mlir::NamedAttribute &attr :
llvm::cast<mlir::DictionaryAttr>(argAttrs)) {
if (attr.getName().getValue() ==
mlir::LLVM::LLVMDialect::getByValAttrName()) {
convertedAttrs.push_back(rewriter.getNamedAttr(
mlir::LLVM::LLVMDialect::getByValAttrName(),
convertTypeAttr(attr)));
} else if (attr.getName().getValue() ==
mlir::LLVM::LLVMDialect::getStructRetAttrName()) {
convertedAttrs.push_back(rewriter.getNamedAttr(
mlir::LLVM::LLVMDialect::getStructRetAttrName(),
convertTypeAttr(attr)));
} else {
convertedAttrs.push_back(attr);
}
}
newArgAttrsArray.emplace_back(
mlir::DictionaryAttr::get(rewriter.getContext(), convertedAttrs));
}
llvmCall.setArgAttrsAttr(rewriter.getArrayAttr(newArgAttrsArray));
}
if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
llvmCall.setResAttrsAttr(resAttrs);
return mlir::success();
Expand Down
43 changes: 34 additions & 9 deletions flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,19 +534,44 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
} else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
fir::CallOp newCall;
if (callOp.getCallee()) {
newCall =
rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
newCall = rewriter->create<fir::CallOp>(loc, *callOp.getCallee(),
newResTys, newOpers);
} else {
// TODO: llvm dialect must be updated to propagate argument on
// attributes for indirect calls. See:
// https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
if (hasByValOrSRetArgs(newInTyAndAttrs))
TODO(loc,
"passing argument or result on the stack in indirect calls");
newOpers[0].setType(mlir::FunctionType::get(
callOp.getContext(),
mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
newCall = rewriter->create<A>(loc, newResTys, newOpers);
newCall = rewriter->create<fir::CallOp>(loc, newResTys, newOpers);
// Set ABI argument attributes on call operation since they are not
// accessible via a FuncOp in indirect calls.
if (hasByValOrSRetArgs(newInTyAndAttrs)) {
llvm::SmallVector<mlir::Attribute> argAttrsArray;
for (const auto &arg :
llvm::ArrayRef<fir::CodeGenSpecifics::TypeAndAttr>(
newInTyAndAttrs)
.drop_front(dropFront)) {
mlir::NamedAttrList argAttrs;
const auto &attr = std::get<fir::CodeGenSpecifics::Attributes>(arg);
if (attr.isByVal()) {
mlir::Type elemType =
fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
argAttrs.set(mlir::LLVM::LLVMDialect::getByValAttrName(),
mlir::TypeAttr::get(elemType));
} else if (attr.isSRet()) {
mlir::Type elemType =
fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
argAttrs.set(mlir::LLVM::LLVMDialect::getStructRetAttrName(),
mlir::TypeAttr::get(elemType));
if (auto align = attr.getAlignment()) {
argAttrs.set(mlir::LLVM::LLVMDialect::getAlignAttrName(),
rewriter->getIntegerAttr(
rewriter->getIntegerType(32), align));
}
}
argAttrsArray.emplace_back(
argAttrs.getDictionary(rewriter->getContext()));
}
newCall.setArgAttrsAttr(rewriter->getArrayAttr(argAttrsArray));
}
}
LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
if (wrap)
Expand Down
14 changes: 14 additions & 0 deletions flang/test/Fir/convert-to-llvm.fir
Original file line number Diff line number Diff line change
Expand Up @@ -2871,3 +2871,17 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
%0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
return %0 : i16
}

// CHECK-LABEL: @test_byval
func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
return
}

// CHECK-LABEL: @test_sret
func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
return
}
22 changes: 22 additions & 0 deletions flang/test/Fir/target-rewrite-indirect-calls.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Test that ABI attributes are set in indirect calls to BIND(C) functions.
// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s

func.func @test(%arg0: () -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
%0 = fir.load %arg1 : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
%1 = fir.convert %arg0 : (() -> ()) -> ((!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ())
fir.call %1(%0, %arg2) proc_attrs<bind_c> : (!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ()
return
}
// CHECK-LABEL: func.func @test(
// CHECK-SAME: %[[VAL_0:.*]]: () -> (),
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>,
// CHECK-SAME: %[[VAL_2:.*]]: f64) {
// CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_0]] : (() -> ()) -> ((!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> ())
// CHECK: %[[VAL_5:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_6:.*]] = fir.alloca !fir.type<t{a:!fir.array<5xf64>}>
// CHECK: fir.store %[[VAL_3]] to %[[VAL_6]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
// CHECK: fir.call %[[VAL_4]](%[[VAL_6]], %[[VAL_2]]) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
// CHECK: llvm.intr.stackrestore %[[VAL_5]] : !llvm.ptr
// CHECK: return
// CHECK: }
15 changes: 15 additions & 0 deletions flang/test/Integration/abi-indirect-call.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!REQUIRES: x86-registered-target
!REQUIRES: flang-supports-f128-math
!RUN: %flang_fc1 -emit-llvm -triple x86_64-unknown-linux-gnu %s -o - | FileCheck %s

! Test ABI of indirect calls is properly implemented in the LLVM IR.

subroutine foo(func_ptr, z)
interface
complex(16) function func_ptr()
end function
end interface
complex(16) :: z
! CHECK: call void %{{.*}}(ptr sret({ fp128, fp128 }) align 16 %{{.*}})
z = func_ptr()
end subroutine