Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] add ABI argument attributes in indirect calls #126896

Merged
merged 2 commits into from
Feb 12, 2025

Conversation

jeanPerier
Copy link
Contributor

Last piece that implements the TODO for sret and byval setting on indirect calls.

This includes a fix to the codegen last patch. I thought types in in type attributes were automatically converted in dialect conversion passes, but that is not the case. The sret and byval type needs to be converted to llvm types in codegen (mlir FuncOp conversion is doing a similar conversion).

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Feb 12, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 12, 2025

@llvm/pr-subscribers-flang-codegen

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (jeanPerier)

Changes

Last piece that implements the TODO for sret and byval setting on indirect calls.

This includes a fix to the codegen last patch. I thought types in in type attributes were automatically converted in dialect conversion passes, but that is not the case. The sret and byval type needs to be converted to llvm types in codegen (mlir FuncOp conversion is doing a similar conversion).


Full diff: https://github.com/llvm/llvm-project/pull/126896.diff

4 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+30-2)
  • (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+34-9)
  • (modified) flang/test/Fir/convert-to-llvm.fir (+14)
  • (added) flang/test/Fir/target-rewrite-indirect-calls.fir (+22)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f938d8d377465..c76b7cde55bdd 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -593,8 +593,36 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
         call, resultTys, adaptor.getOperands(),
         addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                              adaptor.getOperands().size()));
-    if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
-      llvmCall.setArgAttrsAttr(argAttrs);
+    if (mlir::ArrayAttr argAttrsArray = call.getArgAttrsAttr()) {
+      // sret and byval type needs to be converted.
+      auto convertTypeAttr = [&](const mlir::NamedAttribute &attr) {
+        return mlir::TypeAttr::get(convertType(
+            llvm::cast<mlir::TypeAttr>(attr.getValue()).getValue()));
+      };
+      llvm::SmallVector<mlir::Attribute> newArgAttrsArray;
+      for (auto argAttrs : argAttrsArray) {
+        llvm::SmallVector<mlir::NamedAttribute> convertedAttrs;
+        for (const mlir::NamedAttribute &attr :
+             llvm::cast<mlir::DictionaryAttr>(argAttrs)) {
+          if (attr.getName().getValue() ==
+              mlir::LLVM::LLVMDialect::getByValAttrName()) {
+            convertedAttrs.push_back(rewriter.getNamedAttr(
+                mlir::LLVM::LLVMDialect::getByValAttrName(),
+                convertTypeAttr(attr)));
+          } else if (attr.getName().getValue() ==
+                     mlir::LLVM::LLVMDialect::getStructRetAttrName()) {
+            convertedAttrs.push_back(rewriter.getNamedAttr(
+                mlir::LLVM::LLVMDialect::getStructRetAttrName(),
+                convertTypeAttr(attr)));
+          } else {
+            convertedAttrs.push_back(attr);
+          }
+        }
+        newArgAttrsArray.emplace_back(
+            mlir::DictionaryAttr::get(rewriter.getContext(), convertedAttrs));
+      }
+      llvmCall.setArgAttrsAttr(rewriter.getArrayAttr(newArgAttrsArray));
+    }
     if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
       llvmCall.setResAttrsAttr(resAttrs);
     return mlir::success();
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index c099a08ffd30a..5c9da0321bcc4 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -534,19 +534,44 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
       fir::CallOp newCall;
       if (callOp.getCallee()) {
-        newCall =
-            rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
+        newCall = rewriter->create<fir::CallOp>(loc, *callOp.getCallee(),
+                                                newResTys, newOpers);
       } else {
-        // TODO: llvm dialect must be updated to propagate argument on
-        // attributes for indirect calls. See:
-        // https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
-        if (hasByValOrSRetArgs(newInTyAndAttrs))
-          TODO(loc,
-               "passing argument or result on the stack in indirect calls");
         newOpers[0].setType(mlir::FunctionType::get(
             callOp.getContext(),
             mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
-        newCall = rewriter->create<A>(loc, newResTys, newOpers);
+        newCall = rewriter->create<fir::CallOp>(loc, newResTys, newOpers);
+        // Set ABI argument attributes on call operation since they are not
+        // accessible via a FuncOp in indirect calls.
+        if (hasByValOrSRetArgs(newInTyAndAttrs)) {
+          llvm::SmallVector<mlir::Attribute> argAttrsArray;
+          for (const auto &arg :
+               llvm::ArrayRef<fir::CodeGenSpecifics::TypeAndAttr>(
+                   newInTyAndAttrs)
+                   .drop_front(dropFront)) {
+            mlir::NamedAttrList argAttrs;
+            const auto &attr = std::get<fir::CodeGenSpecifics::Attributes>(arg);
+            if (attr.isByVal()) {
+              mlir::Type elemType =
+                  fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
+              argAttrs.set(mlir::LLVM::LLVMDialect::getByValAttrName(),
+                           mlir::TypeAttr::get(elemType));
+            } else if (attr.isSRet()) {
+              mlir::Type elemType =
+                  fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
+              argAttrs.set(mlir::LLVM::LLVMDialect::getStructRetAttrName(),
+                           mlir::TypeAttr::get(elemType));
+              if (auto align = attr.getAlignment()) {
+                argAttrs.set(mlir::LLVM::LLVMDialect::getAlignAttrName(),
+                             rewriter->getIntegerAttr(
+                                 rewriter->getIntegerType(32), align));
+              }
+            }
+            argAttrsArray.emplace_back(
+                argAttrs.getDictionary(rewriter->getContext()));
+          }
+          newCall.setArgAttrsAttr(rewriter->getArrayAttr(argAttrsArray));
+        }
       }
       LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
       if (wrap)
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index c11cfd5d5faa1..8727c0ab08e70 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2871,3 +2871,17 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
   %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
   return %0 : i16
 }
+
+// CHECK-LABEL: @test_byval
+func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+  //  llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
+  fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+  return
+}
+
+// CHECK-LABEL: @test_sret
+func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+  //  llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
+  fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+  return
+}
diff --git a/flang/test/Fir/target-rewrite-indirect-calls.fir b/flang/test/Fir/target-rewrite-indirect-calls.fir
new file mode 100644
index 0000000000000..dbb3d0823520c
--- /dev/null
+++ b/flang/test/Fir/target-rewrite-indirect-calls.fir
@@ -0,0 +1,22 @@
+// Test that ABI attributes are set in indirect calls to BIND(C) functions.
+// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
+
+func.func @test(%arg0: () -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+  %0 = fir.load %arg1 : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+  %1 = fir.convert %arg0 : (() -> ()) -> ((!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ())
+  fir.call %1(%0, %arg2) proc_attrs<bind_c> : (!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ()
+  return
+}
+// CHECK-LABEL:   func.func @test(
+// CHECK-SAME:                    %[[VAL_0:.*]]: () -> (),
+// CHECK-SAME:                    %[[VAL_1:.*]]: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>,
+// CHECK-SAME:                    %[[VAL_2:.*]]: f64) {
+// CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+// CHECK:           %[[VAL_4:.*]] = fir.convert %[[VAL_0]] : (() -> ()) -> ((!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> ())
+// CHECK:           %[[VAL_5:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.type<t{a:!fir.array<5xf64>}>
+// CHECK:           fir.store %[[VAL_3]] to %[[VAL_6]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+// CHECK:           fir.call %[[VAL_4]](%[[VAL_6]], %[[VAL_2]]) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+// CHECK:           llvm.intr.stackrestore %[[VAL_5]] : !llvm.ptr
+// CHECK:           return
+// CHECK:         }

@eugeneepshteyn
Copy link
Contributor

Is there an example that shows how this change affects final LLVM IR?

@jeanPerier
Copy link
Contributor Author

Is there an example that shows how this change affects final LLVM IR?

In the MLIR tests that are testing the translation of LLVM MLIR dialect to LLVM IR, there are tests for argument attributes

https://github.com/llvm/llvm-project/blob/99e1308c41b24e2422324d68be28e5370196e5d6/mlir/test/Target/LLVMIR/call-argument-attributes.mlir

So usually, it is better to limit flang unit tests to test the steps implemented in flang. However, given there are several steps involved here, and given the MLIR tests to not test sret, it sounds reasonable to add a small integration test from Fortran to LLVM IR.

I added one.

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

@jeanPerier jeanPerier merged commit 5836d91 into llvm:main Feb 12, 2025
8 checks passed
@jeanPerier jeanPerier deleted the fir-call-attrs branch February 12, 2025 16:31
flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
Last piece that implements the TODO for sret and byval setting on
indirect calls.

This includes a fix to the codegen last patch. I thought types in in
type attributes were automatically converted in dialect conversion
passes, but that is not the case. The sret and byval type needs to be
converted to llvm types in codegen (mlir FuncOp conversion is doing a
similar conversion).
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
Last piece that implements the TODO for sret and byval setting on
indirect calls.

This includes a fix to the codegen last patch. I thought types in in
type attributes were automatically converted in dialect conversion
passes, but that is not the case. The sret and byval type needs to be
converted to llvm types in codegen (mlir FuncOp conversion is doing a
similar conversion).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants