Skip to content

Commit

Permalink
[runtime] support llvm bitcode control
Browse files Browse the repository at this point in the history
  • Loading branch information
yyp0 committed Aug 20, 2024
1 parent c768cbf commit 1646ddb
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 2 deletions.
10 changes: 10 additions & 0 deletions compiler/include/byteir-c/Translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

enum ByteIRHead {
TOTAL_HEAD_BYTES = 12,
MAGIC_NUMBER = 0x30dc4790,
MAGIC_NUMBER_BYTES = 4,
MAJOR_VERSION = 1,
MAJOR_VERSION_BYTES = 4,
MINOR_VERSION = 0,
MINOR_VERSION_BYTES = 4,
};

#ifdef __cplusplus
extern "C" {
#endif
Expand Down
15 changes: 15 additions & 0 deletions compiler/lib/CAPI/Translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/ToolOutputFile.h"
#include <cstdlib>
#include <string>
Expand Down Expand Up @@ -76,6 +77,20 @@ bool byteirTranslateToLLVMBC(MlirModule module, MlirStringRef outputFile) {
llvm::errs() << "failed to create output file: " << unwrap(outputFile);
return false;
}

// Insert head before WriteBitCodeToFile.
SmallVector<char, 0> head(12);
auto writeInt32ToBuffer = [](uint32_t value, SmallVectorImpl<char> &buffer,
unsigned &position) {
llvm::support::endian::write32le(&buffer[position], value);
position += 4;
};
unsigned position = 0;
writeInt32ToBuffer(MAGIC_NUMBER, head, position);
writeInt32ToBuffer(MAJOR_VERSION, head, position);
writeInt32ToBuffer(MINOR_VERSION, head, position);
fout.write((char *)&head.front(), head.size());

llvm::WriteBitcodeToFile(*llvmModule, fout);
return true;
}
Expand Down
83 changes: 81 additions & 2 deletions runtime/lib/backends/cpu/device/llvm/jit.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "brt/backends/cpu/device/llvm/jit.h"
#include "brt/core/common/common.h"
#include "brt/core/ir/engine_util.h"
#include "byteir-c/Translation.h"
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/ExecutionEngine/JITEventListener.h"
Expand All @@ -32,6 +33,7 @@
#include "llvm/IRReader/IRReader.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"

Expand Down Expand Up @@ -404,18 +406,95 @@ common::Status LLVMJITImpl::LoadTSM(llvm::orc::ThreadSafeModule &&tsm) {
return LLVMErrorToBRTStatus(std::move(err), "Load TSM failed");
}

bool cleanBrtFile(llvm::MemoryBufferRef brtFile,
const std::string tmpFilePath) {
std::error_code ec;
llvm::raw_fd_ostream fout(tmpFilePath, ec);
if (ec) {
llvm::errs() << "failed to create temporary bitcode file: " << tmpFilePath;
return false;
}

fout.write(brtFile.getBufferStart() + TOTAL_HEAD_BYTES,
brtFile.getBufferSize() - TOTAL_HEAD_BYTES);
return true;
}

common::Status processBrtFile(const std::string &path,
std::string &newFilePath) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(path);
if (std::error_code EC = FileOrErr.getError()) {
return common::Status(common::StatusCategory::BRT, common::StatusCode::FAIL,
"Failed to open the input file: " + path);
}
llvm::MemoryBufferRef brtFile = FileOrErr.get()->getMemBufferRef();
const char *bufferData = brtFile.getBufferStart();

// Check magic number.
auto brtMagicNumber = MAGIC_NUMBER;
int hasBrtMagicNumber =
std::memcmp(bufferData, &brtMagicNumber, MAGIC_NUMBER_BYTES);
if (hasBrtMagicNumber) {
newFilePath = path;
return common::Status::OK();
}

printf("Find magic number.\n");
// Check major number.
std::string tmpFilePath = path + ".tmp";
int64_t majorVersionData = 0;
std::memcpy(&majorVersionData, bufferData + MAGIC_NUMBER_BYTES,
MAJOR_VERSION_BYTES);
if (majorVersionData < MAJOR_VERSION) {
printf("Find valid major version number.\n");
if (cleanBrtFile(brtFile, tmpFilePath)) {
newFilePath = tmpFilePath;
}
return common::Status::OK();
} else if (majorVersionData > MAJOR_VERSION) {
return common::Status(common::StatusCategory::BRT, common::StatusCode::FAIL,
"The major version is larger than current llvm "
"version. Stop to load it...");
}

// Compare minor version number.
int64_t minorVersionData = 0;
std::memcpy(&minorVersionData,
bufferData + MAGIC_NUMBER_BYTES + MAJOR_VERSION_BYTES,
MINOR_VERSION_BYTES);
if (minorVersionData <= MINOR_VERSION) {
printf("Find valid major version number.\n");
if (cleanBrtFile(brtFile, tmpFilePath)) {
newFilePath = tmpFilePath;
}
return common::Status::OK();
}

return common::Status(common::StatusCategory::BRT, common::StatusCode::FAIL,
"The minor version is larger than current llvm "
"version. Stop to load it...");
}

common::Status LLVMJITImpl::ParseIRFile(const std::string &path) {
auto ctx = std::make_unique<llvm::LLVMContext>();
llvm::SMDiagnostic err;
auto mod = llvm::parseIRFile(path, err, *ctx);

std::string newFilePath;
auto status = processBrtFile(path, newFilePath);
if (!status.IsOK()) {
return status;
}

auto mod = llvm::parseIRFile(newFilePath, err, *ctx);
if (!mod) {
std::string buf;
llvm::raw_string_ostream OS(buf);
err.print("parse-ir-file", OS);
return common::Status(common::StatusCategory::BRT, common::StatusCode::FAIL,
"Parse LLVM module failed : " + buf);
}
mod->setModuleIdentifier(path);
mod->setModuleIdentifier(newFilePath);

return LoadTSM({std::move(mod), std::move(ctx)});
}
Expand Down
26 changes: 26 additions & 0 deletions runtime/test/backends/cpu/device/llvm_jit_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ using namespace brt::cpu;
using namespace std;

static std::string test_file_add = "test/test_files/LLJIT/add.ll";
static std::string test_file_scatter =
"test/test_files/LLJIT/scatter.brt_head.ll.bc";
static std::string test_file_typecvt = "test/test_files/LLJIT/typecvt.ll";
static std::string test_file_tanh = "test/test_files/LLJIT/tanh.ll";
static std::string test_file_transpose_32_64_64 =
Expand Down Expand Up @@ -67,6 +69,30 @@ TypecvtKernelF32ToF16(const void *src_, void *dst_, const size_t N) {
}
} // namespace

// Test parsing bitcode file with brt head.
TEST(LLVMJITTest, Scatter) {
auto llvmjit = LLVMJIT::Create();
ASSERT_TRUE(llvmjit->LoadFromFile(test_file_scatter).IsOK());

std::vector<int64_t> input_shape{6, 8};
std::vector<int64_t> src_shape{6, 1};
std::vector<float> input_buf(48, 0);
std::vector<float> src_buf(6, 1);
MLIREngineMemRefDescriptor input(input_buf.data(), input_shape),
src(src_buf.data(), src_shape);

{
void *fn;
ASSERT_TRUE(llvmjit->Lookup("_mlir_ciface_memref_copy_kernel", &fn).IsOK());
(*reinterpret_cast<void (*)(void *, void *)>(fn))(src.GetMemrefPtr(),
input.GetMemrefPtr());

for (int i = 0; i < 6; ++i) {
ASSERT_TRUE(input_buf[i * 8] == 1);
}
}
}

TEST(LLVMJITTest, ADD) {
auto llvmjit = LLVMJIT::Create();
ASSERT_TRUE(llvmjit->RegisterSymbol("print", reinterpret_cast<void *>(&print))
Expand Down

0 comments on commit 1646ddb

Please sign in to comment.