Skip to content

Commit

Permalink
fixup! update codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
kumarak committed Nov 16, 2024
1 parent 216defc commit ef5ca84
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 90 deletions.
24 changes: 12 additions & 12 deletions include/patchestry/AST/ASTConsumer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,21 @@ namespace patchestry::ast {
{
public:
explicit PcodeASTConsumer(
clang::CompilerInstance &ci, Program &prog, llvm::raw_ostream &out,
llvm::raw_ostream &ast_out
clang::CompilerInstance &ci, Program &prog, std::string &outfile
)
: program(prog)
, ci(ci)
, out(out)
, ast_out(ast_out)
, codegen(std::make_unique< CodeGenerator >())
, outfile(outfile)
, codegen(std::make_unique< CodeGenerator >(ci))
, type_builder(std::make_unique< TypeBuilder >(ci.getASTContext())) {}

void HandleTranslationUnit(clang::ASTContext &ctx) override;

private:
void set_sema_context(clang::DeclContext *dc);

void write_to_file(void);

void create_globals(clang::ASTContext &ctx, VariableMap &serialized_variables);

void create_functions(
Expand All @@ -75,10 +75,13 @@ namespace patchestry::ast {
clang::FunctionDecl *
create_function_definition(clang::ASTContext &ctx, const Function &function);

std::vector< clang::Stmt * >
create_function_body(clang::ASTContext &ctx, const Function &function);
std::vector< clang::Stmt * > create_function_body(
clang::ASTContext &ctx, clang::FunctionDecl *func_decl, const Function &function
);

void create_label_for_basic_blocks(clang::ASTContext &ctx, const Function &function);
void create_label_for_basic_blocks(
clang::ASTContext &ctx, clang::FunctionDecl *func_decl, const Function &function
);

std::vector< clang::Stmt * > create_basic_block(
clang::ASTContext &ctx, const Function &function, const BasicBlock &block
Expand Down Expand Up @@ -314,14 +317,11 @@ namespace patchestry::ast {

std::reference_wrapper< Program > program;
std::reference_wrapper< clang::CompilerInstance > ci;
std::reference_wrapper< llvm::raw_ostream > out;
std::reference_wrapper< llvm::raw_ostream > ast_out;

std::string outfile;
std::unique_ptr< CodeGenerator > codegen;

std::unique_ptr< TypeBuilder > type_builder;

std::unordered_map< std::string, clang::Decl * > incomplete_definition;
std::unordered_map< std::string, clang::FunctionDecl * > function_declarations;

/* Map of basic block label decls and stmt for creating branch instructions */
Expand Down
16 changes: 11 additions & 5 deletions include/patchestry/AST/Codegen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#pragma once

#include <clang/AST/ASTContext.h>
#include <clang/Frontend/CompilerInstance.h>
#include <vast/Frontend/FrontendAction.hpp>
#include <vast/Frontend/Options.hpp>

namespace llvm {
class raw_fd_ostream;
Expand All @@ -17,16 +20,19 @@ namespace patchestry::ast {
class CodeGenerator
{
public:
CodeGenerator() = default;
explicit CodeGenerator(clang::CompilerInstance &ci) : opts(vast::cc::options(ci)) {}

CodeGenerator(const CodeGenerator &) = default;
CodeGenerator &operator=(const CodeGenerator &) = default;
CodeGenerator(CodeGenerator &&) noexcept = default;
CodeGenerator &operator=(CodeGenerator &&) noexcept = default;
CodeGenerator(const CodeGenerator &) = delete;
CodeGenerator &operator=(const CodeGenerator &) = delete;
CodeGenerator(CodeGenerator &&) noexcept = delete;
CodeGenerator &operator=(CodeGenerator &&) noexcept = delete;

virtual ~CodeGenerator() {}

void generate_source_ir(clang::ASTContext &ctx, llvm::raw_fd_ostream &os);

private:
vast::cc::action_options opts;
};

} // namespace patchestry::ast
34 changes: 25 additions & 9 deletions lib/patchestry/AST/ASTConsumer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,28 @@ namespace patchestry::ast {
);
}

std::error_code ec;
auto out =
std::make_unique< llvm::raw_fd_ostream >(outfile, ec, llvm::sys::fs::OF_Text);

llvm::errs() << "Print AST dump\n";
ctx.getTranslationUnitDecl()->dumpColor();

ctx.getTranslationUnitDecl()->print(out, ctx.getPrintingPolicy(), 0);
ctx.getTranslationUnitDecl()->print(
*llvm::dyn_cast< llvm::raw_ostream >(out), ctx.getPrintingPolicy(), 0
);

llvm::errs() << "Generate mlir\n";
std::error_code ec;
llvm::raw_fd_ostream file_os("/tmp/lifted.mlir", ec);
llvm::raw_fd_ostream file_os(outfile + ".mlir", ec);
codegen->generate_source_ir(ctx, file_os);
}

void PcodeASTConsumer::set_sema_context(clang::DeclContext *dc) {
get_sema().CurContext = dc;
}

void PcodeASTConsumer::write_to_file(void) {}

clang::QualType PcodeASTConsumer::create_function_prototype(
clang::ASTContext &ctx, const FunctionPrototype &proto
) {
Expand Down Expand Up @@ -213,6 +220,7 @@ namespace patchestry::ast {
);

// Add function declaration to tralsation unit
func_decl->setDeclContext(ctx.getTranslationUnitDecl());
ctx.getTranslationUnitDecl()->addDecl(func_decl);

// Set asm label attribute to symbol name
Expand Down Expand Up @@ -271,7 +279,7 @@ namespace patchestry::ast {
auto *func_def = create_function_declaration(ctx, function, true);
if (func_def != nullptr) {
set_sema_context(func_def);
auto body_vec = create_function_body(ctx, function);
auto body_vec = create_function_body(ctx, func_def, function);
func_def->setBody(clang::CompoundStmt::Create(
ctx, body_vec, clang::FPOptionsOverride(), clang::SourceLocation(),
clang::SourceLocation()
Expand All @@ -281,15 +289,16 @@ namespace patchestry::ast {
return func_def;
}

std::vector< clang::Stmt * >
PcodeASTConsumer::create_function_body(clang::ASTContext &ctx, const Function &function) {
std::vector< clang::Stmt * > PcodeASTConsumer::create_function_body(
clang::ASTContext &ctx, clang::FunctionDecl *func_decl, const Function &function
) {
if (function.basic_blocks.empty()) {
llvm::errs() << "Function " << function.name << " doesn't have body\n";
return {};
}

// Create label decl for all basic blocks
create_label_for_basic_blocks(ctx, function);
create_label_for_basic_blocks(ctx, func_decl, function);

std::vector< clang::Stmt * > stmts;

Expand Down Expand Up @@ -330,7 +339,7 @@ namespace patchestry::ast {
}

void PcodeASTConsumer::create_label_for_basic_blocks(
clang::ASTContext &ctx, const Function &function
clang::ASTContext &ctx, clang::FunctionDecl *func_decl, const Function &function
) {
if (function.basic_blocks.empty()) {
llvm::errs() << "Function " << function.name << " does not have any basic block\n";
Expand All @@ -345,9 +354,15 @@ namespace patchestry::ast {
}

auto *label_decl = clang::LabelDecl::Create(
ctx, ctx.getTranslationUnitDecl(), clang::SourceLocation(),
ctx, func_decl, clang::SourceLocation(),
&ctx.Idents.get(label_name_from_key(key))
);

label_decl->setDeclContext(func_decl);
if (clang::DeclContext *dc = label_decl->getLexicalDeclContext()) {
dc->addDecl(label_decl);
}

basic_block_labels.emplace(key, label_decl);
}
}
Expand Down Expand Up @@ -391,6 +406,7 @@ namespace patchestry::ast {
ctx.getTrivialTypeSourceInfo(var_type), clang::SC_Static
);

var_decl->setDeclContext(ctx.getTranslationUnitDecl());
ctx.getTranslationUnitDecl()->addDecl(var_decl);
global_variable_declarations.emplace(variable.key, var_decl);
}
Expand Down
66 changes: 23 additions & 43 deletions lib/patchestry/AST/Codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,12 @@ VAST_UNRELAX_WARNINGS
#include <vast/Dialect/Dialects.hpp>
#include <vast/Dialect/Meta/MetaAttributes.hpp>

#include <vast/Frontend/FrontendAction.hpp>
#include <vast/Frontend/Options.hpp>

#include <vast/CodeGen/AttrVisitorProxy.hpp>
#include <vast/CodeGen/DefaultCodeGenPolicy.hpp>
#include <vast/CodeGen/DefaultMetaGenerator.hpp>
#include <vast/Util/Common.hpp>
#include <vast/Util/DataLayout.hpp>

Expand All @@ -68,33 +72,34 @@ namespace patchestry::ast {

mlir::DialectRegistry registry;
MLIRRegistryInitializer registry_initializer;
mutable mlir::MLIRContext context;
mutable mlir::MLIRContext ctx;

public:
explicit MLIRInitializer(int);

inline mlir::MLIRContext &Context(void) const noexcept { return context; }
inline mlir::MLIRContext &context(void) const noexcept { return ctx; }

~MLIRInitializer(void);
};

MLIRInitializer::MLIRInitializer(int)
: registry()
, registry_initializer(registry)
, context(registry, mlir::MLIRContext::Threading::ENABLED) {
context.disableMultithreading();
context.loadAllAvailableDialects();
context.enableMultithreading();
, ctx(registry, mlir::MLIRContext::Threading::ENABLED) {
ctx.disableMultithreading();
ctx.loadAllAvailableDialects();
ctx.enableMultithreading();
}

MLIRInitializer::~MLIRInitializer(void) { context.disableMultithreading(); }
MLIRInitializer::~MLIRInitializer(void) { ctx.disableMultithreading(); }

static const MLIRInitializer kMLIR(0);

// Custom meta generator that disable all location setting to unknown
class MetaGenerator final : public vast::cg::meta_generator
{
// MLIR context for generating the MLIR location from source location
mlir::MLIRContext *const mctx;
mlir::MLIRContext *mctx;

mlir::Location unknown_location;

Expand All @@ -121,61 +126,36 @@ namespace patchestry::ast {
}
};

class CodeGenPolicy final : public vast::cg::codegen_policy
{
public:
CodeGenPolicy(void) = default;

virtual ~CodeGenPolicy(void) = default;

bool emit_strict_function_return(const vast::cg::clang_function *) const final {
return false;
};

enum vast::cg::missing_return_policy
get_missing_return_policy(const vast::cg::clang_function *) const final {
return vast::cg::missing_return_policy::emit_trap;
}

bool SkipDeclBody(const void *decl) const { return false; }

bool skip_function_body(const vast::cg::clang_function *decl) const final {
return SkipDeclBody(decl);
}

bool skip_global_initializer(const vast::cg::clang_var_decl *decl) const final {
return SkipDeclBody(decl);
}
};

static std::optional< vast::owning_mlir_module_ref > create_module(clang::ASTContext &ctx) {
auto &mctx = kMLIR.Context();
static std::optional< vast::owning_mlir_module_ref >
create_module(clang::ASTContext &ctx, vast::cc::action_options &opts) {
auto &mctx = kMLIR.context();
auto bld = vast::cg::mk_codegen_builder(mctx);
auto mg = std::make_shared< MetaGenerator >(mctx);
auto mg = std::make_shared< vast::cg::default_meta_gen >(&ctx, &mctx);
auto sg =
std::make_shared< vast::cg::default_symbol_generator >(ctx.createMangleContext());
auto cp = std::make_shared< CodeGenPolicy >();
auto cp = std::make_shared< vast::cg::default_policy >(opts);
using vast::cg::as_node;
using vast::cg::as_node_with_list_ref;

auto visitors = std::make_shared< vast::cg::visitor_list >()
| as_node_with_list_ref< vast::cg::attr_visitor_proxy >()
| as_node< vast::cg::type_caching_proxy >()
| as_node_with_list_ref< vast::cg::default_visitor >(mctx, ctx, *bld, mg, sg, cp)
| as_node_with_list_ref< vast::cg::unsup_visitor >(mctx, *bld, mg)
| as_node< vast::cg::fallthrough_visitor >();

vast::cg::driver driver(ctx, mctx, std::move(bld), visitors);
driver.enable_verifier(true);
driver.emit(const_cast< clang::Decl * >(
clang::dyn_cast< clang::Decl >(ctx.getTranslationUnitDecl())
));
for (auto &decl : ctx.getTranslationUnitDecl()->noload_decls()) {
driver.emit(clang::dyn_cast< clang::Decl >(decl));
}

driver.finalize();
return std::make_optional(driver.freeze());
}

void CodeGenerator::generate_source_ir(clang::ASTContext &ctx, llvm::raw_fd_ostream &os) {
auto mod = create_module(ctx);
auto mod = create_module(ctx, opts);
auto flags = mlir::OpPrintingFlags();
flags.enableDebugInfo(true, false);
(*mod)->print(os, flags);
Expand Down
5 changes: 2 additions & 3 deletions lib/patchestry/AST/OperationStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ namespace patchestry::ast {
);

auto *var_decl = create_variable_decl(
ctx, ctx.getTranslationUnitDecl(), op.name, type_iter->second,
ctx, get_sema().CurContext, op.name, type_iter->second,
source_location_from_key(ctx, op.key)
);

Expand All @@ -109,7 +109,7 @@ namespace patchestry::ast {
);

auto *var_decl = create_variable_decl(
ctx, ctx.getTranslationUnitDecl(), op.name, type_iter->second,
ctx, get_sema().CurContext, op.name, type_iter->second,
source_location_from_key(ctx, op.key)
);

Expand Down Expand Up @@ -552,7 +552,6 @@ namespace patchestry::ast {
auto *lhs = create_varnode(ctx, function, op.inputs[0]);
auto *rhs = create_varnode(ctx, function, op.inputs[1]);

get_sema().CurContext = ctx.getTranslationUnitDecl();
auto result = get_sema().CreateBuiltinBinOp(
source_location_from_key(ctx, op.key), clang::BinaryOperatorKind::BO_EQ,
clang::dyn_cast< clang::Expr >(lhs), clang::dyn_cast< clang::Expr >(rhs)
Expand Down
9 changes: 8 additions & 1 deletion lib/patchestry/AST/TypeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ namespace patchestry::ast {
&identifier, tinfo
);

typedef_decl->setDeclContext(ctx.getTranslationUnitDecl());
ctx.getTranslationUnitDecl()->addDecl(typedef_decl);

return ctx.getTypedefType(typedef_decl);
Expand Down Expand Up @@ -185,6 +186,8 @@ namespace patchestry::ast {
);
record_decl->addDecl(field_decl);
}

record_decl->setDeclContext(ctx.getTranslationUnitDecl());
ctx.getTranslationUnitDecl()->addDecl(record_decl);
}

Expand All @@ -208,6 +211,8 @@ namespace patchestry::ast {
source_location_from_key(ctx, composite_type.key),
&ctx.Idents.get(composite_type.name)
);

decl->setDeclContext(ctx.getTranslationUnitDecl());
ctx.getTranslationUnitDecl()->addDecl(decl);
missing_type_definition.emplace(composite_type.key, decl);
return ctx.getRecordType(decl);
Expand All @@ -221,6 +226,8 @@ namespace patchestry::ast {
source_location_from_key(ctx, enum_type.key), &identifier, nullptr, true, false,
false
);

enum_decl->setDeclContext(ctx.getTranslationUnitDecl());
ctx.getTranslationUnitDecl()->addDecl(enum_decl);
return ctx.getEnumType(enum_decl);
}
Expand All @@ -245,7 +252,7 @@ namespace patchestry::ast {
ctx, ctx.getTranslationUnitDecl(), clang::SourceLocation(), clang::SourceLocation(),
&ctx.Idents.get(undefined_type.name), ctx.getTrivialTypeSourceInfo(base_type)
);

typedef_decl->setDeclContext(ctx.getTranslationUnitDecl());
ctx.getTranslationUnitDecl()->addDecl(typedef_decl);
return ctx.getTypedefType(typedef_decl);
}
Expand Down
Loading

0 comments on commit ef5ca84

Please sign in to comment.