Skip to content

Commit

Permalink
Fixed conversion flag
Browse files Browse the repository at this point in the history
  • Loading branch information
Berke-Ates committed Sep 28, 2023
1 parent 750ea7d commit ca3896b
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 51 deletions.
2 changes: 1 addition & 1 deletion include/SDFG/Conversion/GenericToSDFG/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name GenericToSDFG)
add_public_tablegen_target(MLIRGenericToSDFGPassIncGen)

target_sources(SOURCE_FILES_H PRIVATE PassDetail.h Passes.h)
target_sources(SOURCE_FILES_H PRIVATE Passes.h)
23 changes: 0 additions & 23 deletions include/SDFG/Conversion/GenericToSDFG/PassDetail.h

This file was deleted.

9 changes: 3 additions & 6 deletions include/SDFG/Conversion/GenericToSDFG/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@

namespace mlir::sdfg::conversion {

/// Creates a generic to sdfg converting pass
std::unique_ptr<Pass> createGenericToSDFGPass(StringRef getMainFuncName = "");

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
/// Generate the code for declaring passes.
#define GEN_PASS_DECL
#include "SDFG/Conversion/GenericToSDFG/Passes.h.inc"

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
Expand Down
1 change: 0 additions & 1 deletion include/SDFG/Conversion/GenericToSDFG/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ include "SDFG/Dialect/Dialect.td"
/// Define generic to SDFG pass.
def GenericToSDFGPass : Pass<"convert-to-sdfg", "ModuleOp"> {
let summary = "Convert SCF, Arith, Math and Memref dialect to SDFG dialect";
let constructor = "mlir::sdfg::conversion::createGenericToSDFGPass()";
let dependentDialects = ["mlir::sdfg::SDFGDialect"];
let options = [
Option<"mainFuncName", "main-func-name", "std::string", /*default=*/"",
Expand Down
41 changes: 21 additions & 20 deletions lib/SDFG/Conversion/GenericToSDFG/ConvertGenericToSDFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

/// This file defines a converter from builtin dialects to the SDFG dialect.

#include "SDFG/Conversion/GenericToSDFG/PassDetail.h"
#include "SDFG/Conversion/GenericToSDFG/Passes.h"
#include "SDFG/Dialect/Dialect.h"
#include "SDFG/Utils/Utils.h"
Expand All @@ -17,6 +16,11 @@ using namespace mlir;
using namespace sdfg;
using namespace conversion;

namespace mlir::sdfg::conversion {
#define GEN_PASS_DEF_GENERICTOSDFGPASS
#include "SDFG/Conversion/GenericToSDFG/Passes.h.inc"
} // namespace mlir::sdfg::conversion

//===----------------------------------------------------------------------===//
// Target & Type Converter
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -245,6 +249,9 @@ static Value getTransientValue(Value val) {
// Func Patterns
//===----------------------------------------------------------------------===//

// FIXME: Find a cleaner way to pass the name to the pattern.
llvm::StringRef mainfuncName;

/// Converts a func::FuncOp to a SDFG node.
class FuncToSDFG : public OpConversionPattern<func::FuncOp> {
public:
Expand All @@ -258,15 +265,14 @@ class FuncToSDFG : public OpConversionPattern<func::FuncOp> {
return success();
}

// TODO: Should be passed by a subflag
if (!op.getName().equals("main")) {
if (!op.getName().equals(mainfuncName)) {
// NOTE: The nested SDFG is created at the call operation conversion
rewriter.eraseOp(op);
return success();
}

// HACK: Replaces the print_array call with returning arrays (PolybenchC)
if (op.getName().equals("main")) {
if (op.getName().equals(mainfuncName)) {
for (int i = op.getNumArguments() - 1; i >= 0; --i)
if (op.getArgument(i).getUses().empty())
op.eraseArgument(i);
Expand Down Expand Up @@ -1750,13 +1756,11 @@ void populateGenericToSDFGConversionPatterns(RewritePatternSet &patterns,

namespace {
struct GenericToSDFGPass
: public sdfg::conversion::GenericToSDFGPassBase<GenericToSDFGPass> {
std::string mainFuncName;

: public sdfg::conversion::impl::GenericToSDFGPassBase<GenericToSDFGPass> {
GenericToSDFGPass() = default;

explicit GenericToSDFGPass(StringRef mainFuncName)
: mainFuncName(mainFuncName.str()) {}
GenericToSDFGPass(const sdfg::conversion::GenericToSDFGPassOptions &options)
: sdfg::conversion::impl::GenericToSDFGPassBase<GenericToSDFGPass>(
options) {}

void runOnOperation() override;
};
Expand Down Expand Up @@ -1803,10 +1807,13 @@ llvm::Optional<std::string> getMainFunctionName(ModuleOp moduleOp) {
void GenericToSDFGPass::runOnOperation() {
ModuleOp module = getOperation();

// FIXME: Find a way to get func name via CLI instead of inferring
llvm::Optional<std::string> mainFuncNameOpt = getMainFunctionName(module);
if (mainFuncNameOpt)
mainFuncName = *mainFuncNameOpt;
// Get name of the main function to convert.
if (this->mainFuncName.empty()) {
llvm::Optional<std::string> mainFuncNameOpt = getMainFunctionName(module);
if (mainFuncNameOpt.has_value())
this->mainFuncName = mainFuncNameOpt.value();
}
mainfuncName = this->mainFuncName;

// Clear all attributes
for (NamedAttribute a : module->getAttrs())
Expand All @@ -1821,9 +1828,3 @@ void GenericToSDFGPass::runOnOperation() {
if (applyFullConversion(module, target, std::move(patterns)).failed())
signalPassFailure();
}

/// Returns a unique pointer to this pass.
std::unique_ptr<Pass>
conversion::createGenericToSDFGPass(StringRef getMainFuncName) {
return std::make_unique<GenericToSDFGPass>(getMainFuncName);
}

0 comments on commit ca3896b

Please sign in to comment.