Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed Conversion Flag #31

Merged
merged 2 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/SDFG/Conversion/GenericToSDFG/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name GenericToSDFG)
add_public_tablegen_target(MLIRGenericToSDFGPassIncGen)

target_sources(SOURCE_FILES_H PRIVATE PassDetail.h Passes.h)
target_sources(SOURCE_FILES_H PRIVATE Passes.h)
23 changes: 0 additions & 23 deletions include/SDFG/Conversion/GenericToSDFG/PassDetail.h

This file was deleted.

9 changes: 3 additions & 6 deletions include/SDFG/Conversion/GenericToSDFG/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@

namespace mlir::sdfg::conversion {

/// Creates a generic to sdfg converting pass
std::unique_ptr<Pass> createGenericToSDFGPass(StringRef getMainFuncName = "");

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
/// Generate the code for declaring passes.
#define GEN_PASS_DECL
#include "SDFG/Conversion/GenericToSDFG/Passes.h.inc"

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
Expand Down
1 change: 0 additions & 1 deletion include/SDFG/Conversion/GenericToSDFG/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ include "SDFG/Dialect/Dialect.td"
/// Define generic to SDFG pass.
def GenericToSDFGPass : Pass<"convert-to-sdfg", "ModuleOp"> {
let summary = "Convert SCF, Arith, Math and Memref dialect to SDFG dialect";
let constructor = "mlir::sdfg::conversion::createGenericToSDFGPass()";
let dependentDialects = ["mlir::sdfg::SDFGDialect"];
let options = [
Option<"mainFuncName", "main-func-name", "std::string", /*default=*/"",
Expand Down
41 changes: 21 additions & 20 deletions lib/SDFG/Conversion/GenericToSDFG/ConvertGenericToSDFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

/// This file defines a converter from builtin dialects to the SDFG dialect.

#include "SDFG/Conversion/GenericToSDFG/PassDetail.h"
#include "SDFG/Conversion/GenericToSDFG/Passes.h"
#include "SDFG/Dialect/Dialect.h"
#include "SDFG/Utils/Utils.h"
Expand All @@ -17,6 +16,11 @@ using namespace mlir;
using namespace sdfg;
using namespace conversion;

namespace mlir::sdfg::conversion {
#define GEN_PASS_DEF_GENERICTOSDFGPASS
#include "SDFG/Conversion/GenericToSDFG/Passes.h.inc"
} // namespace mlir::sdfg::conversion

//===----------------------------------------------------------------------===//
// Target & Type Converter
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -245,6 +249,9 @@ static Value getTransientValue(Value val) {
// Func Patterns
//===----------------------------------------------------------------------===//

// FIXME: Find a cleaner way to pass the name to the pattern.
llvm::StringRef mainfuncName;

/// Converts a func::FuncOp to a SDFG node.
class FuncToSDFG : public OpConversionPattern<func::FuncOp> {
public:
Expand All @@ -258,15 +265,14 @@ class FuncToSDFG : public OpConversionPattern<func::FuncOp> {
return success();
}

// TODO: Should be passed by a subflag
if (!op.getName().equals("main")) {
if (!op.getName().equals(mainfuncName)) {
// NOTE: The nested SDFG is created at the call operation conversion
rewriter.eraseOp(op);
return success();
}

// HACK: Replaces the print_array call with returning arrays (PolybenchC)
if (op.getName().equals("main")) {
if (op.getName().equals(mainfuncName)) {
for (int i = op.getNumArguments() - 1; i >= 0; --i)
if (op.getArgument(i).getUses().empty())
op.eraseArgument(i);
Expand Down Expand Up @@ -1750,13 +1756,11 @@ void populateGenericToSDFGConversionPatterns(RewritePatternSet &patterns,

namespace {
struct GenericToSDFGPass
: public sdfg::conversion::GenericToSDFGPassBase<GenericToSDFGPass> {
std::string mainFuncName;

: public sdfg::conversion::impl::GenericToSDFGPassBase<GenericToSDFGPass> {
GenericToSDFGPass() = default;

explicit GenericToSDFGPass(StringRef mainFuncName)
: mainFuncName(mainFuncName.str()) {}
GenericToSDFGPass(const sdfg::conversion::GenericToSDFGPassOptions &options)
: sdfg::conversion::impl::GenericToSDFGPassBase<GenericToSDFGPass>(
options) {}

void runOnOperation() override;
};
Expand Down Expand Up @@ -1803,10 +1807,13 @@ llvm::Optional<std::string> getMainFunctionName(ModuleOp moduleOp) {
void GenericToSDFGPass::runOnOperation() {
ModuleOp module = getOperation();

// FIXME: Find a way to get func name via CLI instead of inferring
llvm::Optional<std::string> mainFuncNameOpt = getMainFunctionName(module);
if (mainFuncNameOpt)
mainFuncName = *mainFuncNameOpt;
// Get name of the main function to convert.
if (this->mainFuncName.empty()) {
llvm::Optional<std::string> mainFuncNameOpt = getMainFunctionName(module);
if (mainFuncNameOpt.has_value())
this->mainFuncName = mainFuncNameOpt.value();
}
mainfuncName = this->mainFuncName;

// Clear all attributes
for (NamedAttribute a : module->getAttrs())
Expand All @@ -1821,9 +1828,3 @@ void GenericToSDFGPass::runOnOperation() {
if (applyFullConversion(module, target, std::move(patterns)).failed())
signalPassFailure();
}

/// Returns a unique pointer to this pass.
std::unique_ptr<Pass>
conversion::createGenericToSDFGPass(StringRef getMainFuncName) {
return std::make_unique<GenericToSDFGPass>(getMainFuncName);
}