diff --git a/include/SDFG/Conversion/GenericToSDFG/CMakeLists.txt b/include/SDFG/Conversion/GenericToSDFG/CMakeLists.txt index e7e107e9..414165f0 100644 --- a/include/SDFG/Conversion/GenericToSDFG/CMakeLists.txt +++ b/include/SDFG/Conversion/GenericToSDFG/CMakeLists.txt @@ -4,4 +4,4 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name GenericToSDFG) add_public_tablegen_target(MLIRGenericToSDFGPassIncGen) -target_sources(SOURCE_FILES_H PRIVATE PassDetail.h Passes.h) +target_sources(SOURCE_FILES_H PRIVATE Passes.h) diff --git a/include/SDFG/Conversion/GenericToSDFG/PassDetail.h b/include/SDFG/Conversion/GenericToSDFG/PassDetail.h deleted file mode 100644 index 87d669bf..00000000 --- a/include/SDFG/Conversion/GenericToSDFG/PassDetail.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich - -/// Header for Generic to SDFG conversion pass details. - -#ifndef SDFG_Conversion_GenericToSDFG_PassDetail_H -#define SDFG_Conversion_GenericToSDFG_PassDetail_H - -#include "SDFG/Dialect/Dialect.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace sdfg { -namespace conversion { - -/// Generate the code for base classes. -#define GEN_PASS_CLASSES -#include "SDFG/Conversion/GenericToSDFG/Passes.h.inc" - -} // namespace conversion -} // namespace sdfg -} // end namespace mlir - -#endif // SDFG_Conversion_GenericToSDFG_PassDetail_H diff --git a/include/SDFG/Conversion/GenericToSDFG/Passes.h b/include/SDFG/Conversion/GenericToSDFG/Passes.h index a2858136..9b6b8b2c 100644 --- a/include/SDFG/Conversion/GenericToSDFG/Passes.h +++ b/include/SDFG/Conversion/GenericToSDFG/Passes.h @@ -9,12 +9,9 @@ namespace mlir::sdfg::conversion { -/// Creates a generic to sdfg converting pass -std::unique_ptr createGenericToSDFGPass(StringRef getMainFuncName = ""); - -//===----------------------------------------------------------------------===// -// Registration -//===----------------------------------------------------------------------===// +/// Generate the code for declaring passes. +#define GEN_PASS_DECL +#include "SDFG/Conversion/GenericToSDFG/Passes.h.inc" /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION diff --git a/include/SDFG/Conversion/GenericToSDFG/Passes.td b/include/SDFG/Conversion/GenericToSDFG/Passes.td index f2290ddd..861de2e8 100644 --- a/include/SDFG/Conversion/GenericToSDFG/Passes.td +++ b/include/SDFG/Conversion/GenericToSDFG/Passes.td @@ -11,7 +11,6 @@ include "SDFG/Dialect/Dialect.td" /// Define generic to SDFG pass. def GenericToSDFGPass : Pass<"convert-to-sdfg", "ModuleOp"> { let summary = "Convert SCF, Arith, Math and Memref dialect to SDFG dialect"; - let constructor = "mlir::sdfg::conversion::createGenericToSDFGPass()"; let dependentDialects = ["mlir::sdfg::SDFGDialect"]; let options = [ Option<"mainFuncName", "main-func-name", "std::string", /*default=*/"", diff --git a/lib/SDFG/Conversion/GenericToSDFG/ConvertGenericToSDFG.cpp b/lib/SDFG/Conversion/GenericToSDFG/ConvertGenericToSDFG.cpp index 9cfc2379..da09e164 100644 --- a/lib/SDFG/Conversion/GenericToSDFG/ConvertGenericToSDFG.cpp +++ b/lib/SDFG/Conversion/GenericToSDFG/ConvertGenericToSDFG.cpp @@ -2,7 +2,6 @@ /// This file defines a converter from builtin dialects to the SDFG dialect. -#include "SDFG/Conversion/GenericToSDFG/PassDetail.h" #include "SDFG/Conversion/GenericToSDFG/Passes.h" #include "SDFG/Dialect/Dialect.h" #include "SDFG/Utils/Utils.h" @@ -17,6 +16,11 @@ using namespace mlir; using namespace sdfg; using namespace conversion; +namespace mlir::sdfg::conversion { +#define GEN_PASS_DEF_GENERICTOSDFGPASS +#include "SDFG/Conversion/GenericToSDFG/Passes.h.inc" +} // namespace mlir::sdfg::conversion + //===----------------------------------------------------------------------===// // Target & Type Converter //===----------------------------------------------------------------------===// @@ -245,6 +249,9 @@ static Value getTransientValue(Value val) { // Func Patterns //===----------------------------------------------------------------------===// +// FIXME: Find a cleaner way to pass the name to the pattern. +llvm::StringRef mainfuncName; + /// Converts a func::FuncOp to a SDFG node. class FuncToSDFG : public OpConversionPattern { public: @@ -258,15 +265,14 @@ class FuncToSDFG : public OpConversionPattern { return success(); } - // TODO: Should be passed by a subflag - if (!op.getName().equals("main")) { + if (!op.getName().equals(mainfuncName)) { // NOTE: The nested SDFG is created at the call operation conversion rewriter.eraseOp(op); return success(); } // HACK: Replaces the print_array call with returning arrays (PolybenchC) - if (op.getName().equals("main")) { + if (op.getName().equals(mainfuncName)) { for (int i = op.getNumArguments() - 1; i >= 0; --i) if (op.getArgument(i).getUses().empty()) op.eraseArgument(i); @@ -1750,13 +1756,11 @@ void populateGenericToSDFGConversionPatterns(RewritePatternSet &patterns, namespace { struct GenericToSDFGPass - : public sdfg::conversion::GenericToSDFGPassBase { - std::string mainFuncName; - + : public sdfg::conversion::impl::GenericToSDFGPassBase { GenericToSDFGPass() = default; - - explicit GenericToSDFGPass(StringRef mainFuncName) - : mainFuncName(mainFuncName.str()) {} + GenericToSDFGPass(const sdfg::conversion::GenericToSDFGPassOptions &options) + : sdfg::conversion::impl::GenericToSDFGPassBase( + options) {} void runOnOperation() override; }; @@ -1803,10 +1807,13 @@ llvm::Optional getMainFunctionName(ModuleOp moduleOp) { void GenericToSDFGPass::runOnOperation() { ModuleOp module = getOperation(); - // FIXME: Find a way to get func name via CLI instead of inferring - llvm::Optional mainFuncNameOpt = getMainFunctionName(module); - if (mainFuncNameOpt) - mainFuncName = *mainFuncNameOpt; + // Get name of the main function to convert. + if (this->mainFuncName.empty()) { + llvm::Optional mainFuncNameOpt = getMainFunctionName(module); + if (mainFuncNameOpt.has_value()) + this->mainFuncName = mainFuncNameOpt.value(); + } + mainfuncName = this->mainFuncName; // Clear all attributes for (NamedAttribute a : module->getAttrs()) @@ -1821,9 +1828,3 @@ void GenericToSDFGPass::runOnOperation() { if (applyFullConversion(module, target, std::move(patterns)).failed()) signalPassFailure(); } - -/// Returns a unique pointer to this pass. -std::unique_ptr -conversion::createGenericToSDFGPass(StringRef getMainFuncName) { - return std::make_unique(getMainFuncName); -} diff --git a/test/SDFG/Converter/toSDFG/func/entry.mlir b/test/SDFG/Converter/toSDFG/func/entry.mlir new file mode 100644 index 00000000..580882bf --- /dev/null +++ b/test/SDFG/Converter/toSDFG/func/entry.mlir @@ -0,0 +1,11 @@ +// RUN: sdfg-opt --convert-to-sdfg=main-func-name="f2" %s | sdfg-opt | FileCheck %s +// CHECK: arith.addi +func.func private @f1(%arg1: i32, %arg2: i32) -> i32 { + %c0 = arith.subi %arg1, %arg2 : i32 + return %c0 : i32 +} + +func.func private @f2(%arg1: i32, %arg2: i32) -> i32 { + %c0 = arith.addi %arg1, %arg2 : i32 + return %c0 : i32 +} diff --git a/test/SDFG/Converter/toSDFG/func/entry2.mlir b/test/SDFG/Converter/toSDFG/func/entry2.mlir new file mode 100644 index 00000000..3b2e0dbd --- /dev/null +++ b/test/SDFG/Converter/toSDFG/func/entry2.mlir @@ -0,0 +1,11 @@ +// RUN: sdfg-opt --convert-to-sdfg=main-func-name="f1" %s | sdfg-opt | FileCheck %s +// CHECK: arith.subi +func.func private @f1(%arg1: i32, %arg2: i32) -> i32 { + %c0 = arith.subi %arg1, %arg2 : i32 + return %c0 : i32 +} + +func.func private @f2(%arg1: i32, %arg2: i32) -> i32 { + %c0 = arith.addi %arg1, %arg2 : i32 + return %c0 : i32 +}