Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test the return value of omMMapBinaryFile function and terminate the main program elegantly #3002

Merged
merged 13 commits into from
Nov 15, 2024
Merged
21 changes: 14 additions & 7 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ bool extractConstantsToFile(ModuleOp &module, std::string filepath,
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(module.getBody());
std::string fname = llvm::sys::path::filename(filepath).str() + '\0';
fname = (isZOS(module)) ? krnl::e2a_s(fname) : fname;
mlir::StringAttr valueAttr = mlir::StringAttr::get(context, fname);
create.llvm.globalOp(LLVM::LLVMArrayType::get(llvmI8Ty, fname.size()),
/*isConstant=*/true, LLVM::Linkage::Internal,
Expand Down Expand Up @@ -612,15 +613,15 @@ void loadConstantsFromFile(ModuleOp &module,
OpBuilder b(ctx);
MultiDialectBuilder<LLVMBuilder> create(b, loc);

Type llvmI1Ty = IntegerType::get(ctx, 1);
Type llvmI8Ty = IntegerType::get(ctx, 8);
Type llvmI64Ty = IntegerType::get(ctx, 64);
Type llvmI8PtrTy = getPointerType(ctx, llvmI8Ty);
Type llvmVoidTy = LLVM::LLVMVoidType::get(ctx);

// The following function will be emitted inside the IR to load constants from
// file.
std::string loadAllConstantsFuncName = "omLoadConstantsFromFile";
Type llvmFnType = LLVM::LLVMFunctionType::get(llvmVoidTy, {}, false);
Type llvmFnType = LLVM::LLVMFunctionType::get(llvmI1Ty, {}, false);

// If calledByEntryPoint, this function will be called by entry points.
// Otherwise, user program (C/C++/Java/Python) would call this function.
Expand All @@ -629,6 +630,7 @@ void loadConstantsFromFile(ModuleOp &module,
Operation *firstEntryPointOp =
getFirstEntryOpInBlock(module, entryGlobalOps);
assert(firstEntryPointOp && "No entry function exists");
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(firstEntryPointOp);
funcOp = create.llvm.func(
loadAllConstantsFuncName, llvmFnType, /*createUniqueFunc=*/true);
Expand All @@ -646,13 +648,16 @@ void loadConstantsFromFile(ModuleOp &module,
std::find(entryName.begin(), entryName.end(), '\0'), entryName.end());
auto entryFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(entryName);
assert(entryFunc && "Entry function not found");
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(
&entryFunc.getBody().front(), entryFunc.getBody().front().begin());
FlatSymbolRefAttr loadAllConstantsRef = create.llvm.getOrInsertSymbolRef(
module, LLVMBuilder::SymbolPostfix(module, loadAllConstantsFuncName),
llvmVoidTy, {},
llvmI1Ty, {},
/*isVarArg=*/false);
create.llvm.call({}, loadAllConstantsRef, {});
Value retVal = create.llvm.call({llvmI1Ty}, loadAllConstantsRef, {});
equalOrFailed(module, b, loc,
create.llvm.constant(llvmI1Ty, static_cast<int64_t>(1)), retVal);
}
} else {
OpBuilder::InsertionGuard guard(b);
Expand Down Expand Up @@ -697,8 +702,11 @@ void loadConstantsFromFile(ModuleOp &module,
// Call a function to mmap the binary file to memory.
Value isleVal = create.llvm.constant(llvmI64Ty, isle);
Value sizeVal = create.llvm.constant(llvmI64Ty, dataSize);
RuntimeAPI::callApi(b, loc, apiRegistry, RuntimeAPI::API::MMAP_BINARY_FILE,
Value retVal = RuntimeAPI::callApi(b, loc, apiRegistry,
RuntimeAPI::API::MMAP_BINARY_FILE,
{packedGlobalPtr, fnameI8Ptr, sizeVal, isleVal});
equalOrReturn(module, b, loc,
create.llvm.constant(llvmI1Ty, static_cast<int64_t>(1)), retVal, retVal);

// Now set pointers for constants in the IR
module->walk([&](LLVM::GlobalOp dataGlobalOp) -> WalkResult {
Expand All @@ -725,11 +733,10 @@ void loadConstantsFromFile(ModuleOp &module,
RuntimeAPI::callApi(b, loc, apiRegistry,
RuntimeAPI::API::GET_EXTERNAL_CONSTANT_ADDR,
{dataPtr, packedGlobalPtr, offsetVal});

return WalkResult::advance();
});

create.llvm._return();
create.llvm._return(create.llvm.constant(llvmI1Ty, static_cast<int64_t>(1)));
}

//===----------------------------------------------------------------------===//
Expand Down
25 changes: 0 additions & 25 deletions src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,31 +412,6 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
rewriter.getI64Type(), {rewriter.getI64Type()});
}

// Emit code for `IF lhs != rhs THEN return null ELSE do nothing`.
//
// The comparison is not evaluated at compile time: an LLVM-dialect `icmp ne`
// on `lhs` and `rhs` is emitted as the condition of an if-then-else built via
// `create.llvm.ifThenElse`. Only a "then" branch is supplied; it emits, in
// this order:
//   1. a krnl.printf of `errorMsg` (with `rhs` appended when `appendRHS` is
//      true, otherwise `errorMsg` followed by a newline),
//   2. code setting errno to EINVAL (through krnl::emitErrNo),
//   3. a return of a null i8 pointer.
//
// Parameters:
//   module    - module handed to krnl::emitErrNo.
//   rewriter  - builder used to emit the operations.
//   loc       - location attached to the emitted operations.
//   lhs, rhs  - values compared for inequality.
//   errorMsg  - text to print on mismatch. Note that the printf is emitted
//               even when this is empty (there is no emptiness check here).
//   appendRHS - selects which of the two printf forms is emitted.
void equalOrFailed(ModuleOp &module, PatternRewriter &rewriter, Location loc,
Value lhs, Value rhs, std::string errorMsg = "",
bool appendRHS = true) const {
MLIRContext *context = rewriter.getContext();
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
create.llvm.ifThenElse(/*cond=*/
[&](const LLVMBuilder &createLLVM) {
// Condition: true when the two values differ.
return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs);
}, /*then=*/
[&](const LLVMBuilder &createLLVM) {
// This `create` shadows the outer one and is constructed from the builder
// passed to the "then" callback; it additionally exposes the Krnl builder
// needed for printf.
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
// Print an error message.
if (appendRHS)
// `rhs` is passed together with an i64 type; the meaning of the trailing
// `true` flag is defined by KrnlBuilder::printf (not visible here).
create.krnl.printf(
StringRef(errorMsg), rhs, rewriter.getI64Type(), true);
else
create.krnl.printf(StringRef(errorMsg + "\n"));
// Set errno.
krnl::emitErrNo(module, rewriter, loc, EINVAL);
// Return NULL.
create.llvm._return(create.llvm.null(getI8PointerType(context)));
});
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this code into KrnlToLLVMHelper.{hpp/cpp} so that it can be reused.

void emitVerificationCodeForInputTensors(ModuleOp &module,
PatternRewriter &rewriter, Location loc,
const RuntimeAPIRegistry &apiRegistry, Value omTensorInputs,
Expand Down
43 changes: 43 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "onnx-mlir/Compiler/OMCompilerRuntimeTypes.h"
#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
Expand Down Expand Up @@ -342,5 +343,47 @@ bool isZOS(ModuleOp module) {
return zOS;
}

void equalOrFailed(ModuleOp &module, OpBuilder &rewriter, Location loc,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would these calls also be useful if we wanted to fail the inference, say after a severe failure in an NNPA call?
If so, we may want to see whether we would benefit from introducing something in a higher-level dialect (say Krnl or SCF) that would then feed into this one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally yes. It will be used anywhere in the main inference when we want to fail the inference. We can introduce a krnl op for this.

Value lhs, Value rhs, std::string errorMsg, bool appendRHS) {
MLIRContext *context = rewriter.getContext();
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(rewriter, loc);
create.llvm.ifThenElse(/*cond=*/
[&](const LLVMBuilder &createLLVM) {
return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs);
}, /*then=*/
[&](const LLVMBuilder &createLLVM) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
// Print an error message.
if (!errorMsg.empty()) {
if (appendRHS)
create.krnl.printf(
StringRef(errorMsg), rhs, rewriter.getI64Type(), true);
else
create.krnl.printf(StringRef(errorMsg + "\n"));
}
// Set errno.
emitErrNo(module, rewriter, loc, EINVAL);
// Return NULL.
create.llvm._return(create.llvm.null(getI8PointerType(context)));
});
}

void equalOrReturn(ModuleOp &module, OpBuilder &rewriter, Location loc,
Value lhs, Value rhs, Value retVal, std::string errorMsg) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(rewriter, loc);
create.llvm.ifThenElse(/*cond=*/
[&](const LLVMBuilder &createLLVM) {
return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs);
}, /*then=*/
[&](const LLVMBuilder &createLLVM) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
// Print an error message.
if (!errorMsg.empty())
create.krnl.printf(StringRef(errorMsg + "\n"));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we start to use krnl.printf for more than debugging, note that I have often experienced issues with that call (in debugging situations) when it prints too long a string. We have some issues with that call.

I had a similar issue in krnl.PrintTensor, and I alleviated it by putting a special %e sequence at the end of the string. For instrumentation, I think I encoded the string length in an integer flag. These are all tricks to get around this issue. If someone could try to understand what is happening, that would be great!

// Return retVal.
create.llvm._return(retVal);
});
}

} // namespace krnl
} // namespace onnx_mlir
10 changes: 10 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ std::string e2a_s(std::string e_s);
void emitErrNo(mlir::ModuleOp module, mlir::OpBuilder &builder,
mlir::Location loc, int err);

/// Emit code for `IF lhs != rhs THEN return null ELSE do nothing`.
///
/// The check is emitted into the IR (an `icmp ne` guarding an if-then-else),
/// not evaluated at compile time. On mismatch, the emitted code:
///  - prints `errorMsg` when it is non-empty (with `rhs` appended if
///    `appendRHS` is true, otherwise followed by a newline),
///  - sets errno to EINVAL,
///  - returns a null i8 pointer.
void equalOrFailed(mlir::ModuleOp &module, mlir::OpBuilder &rewriter,
mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
std::string errorMsg = "", bool appendRHS = true);

/// Emit code for `IF lhs != rhs THEN return retVal ELSE do nothing`.
///
/// The check is emitted into the IR (an `icmp ne` guarding an if-then-else),
/// not evaluated at compile time. On mismatch, the emitted code prints
/// `errorMsg` followed by a newline when `errorMsg` is non-empty, and then
/// returns `retVal`. Unlike equalOrFailed, errno is not modified.
void equalOrReturn(mlir::ModuleOp &module, mlir::OpBuilder &rewriter,
mlir::Location loc, mlir::Value lhs, mlir::Value rhs, mlir::Value retVal,
std::string errorMsg = "");

/// Creates an LLVM pointer type with the given element type and address space.
/// This function is meant to be used in code supporting both typed and opaque
/// pointers, as it will create an opaque pointer with the given address space
Expand Down
3 changes: 2 additions & 1 deletion src/Conversion/KrnlToLLVM/RuntimeAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ RuntimeAPIRegistry::RuntimeAPIRegistry(
: registry() {
MLIRContext *context = module.getContext();
auto voidTy = LLVM::LLVMVoidType::get(context);
Type int1Ty = IntegerType::get(context, 1);
auto int8Ty = IntegerType::get(context, 8);
auto opaquePtrTy = onnx_mlir::krnl::getPointerType(context, int8Ty);
auto opaquePtrPtrTy = onnx_mlir::krnl::getPointerType(context, opaquePtrTy);
Expand All @@ -88,7 +89,7 @@ RuntimeAPIRegistry::RuntimeAPIRegistry(
RuntimeAPI(API::GET_OMT_ARRAY, "omTensorListGetOmtArray", opaquePtrPtrTy, {opaquePtrTy}),
RuntimeAPI(API::PRINT_OMTENSOR, "omTensorPrint", voidTy, {opaquePtrTy, opaquePtrTy}),
RuntimeAPI(API::GET_OMTENSOR_LIST_SIZE, "omTensorListGetSize", int64Ty, {opaquePtrTy}),
RuntimeAPI(API::MMAP_BINARY_FILE, "omMMapBinaryFile", voidTy, {opaquePtrPtrTy, opaquePtrTy, int64Ty, int64Ty}),
RuntimeAPI(API::MMAP_BINARY_FILE, "omMMapBinaryFile", int1Ty, {opaquePtrPtrTy, opaquePtrTy, int64Ty, int64Ty}),
RuntimeAPI(API::GET_EXTERNAL_CONSTANT_ADDR, "omGetExternalConstantAddr", voidTy, {opaquePtrPtrTy, opaquePtrPtrTy, int64Ty}),
};
// clang-format on
Expand Down
120 changes: 54 additions & 66 deletions src/Runtime/OMExternalConstant.inc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ typedef int make_iso_compilers_happy;

#include <errno.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
Expand Down Expand Up @@ -58,88 +59,75 @@ void checkEndianness(const char constPackIsLE) {
///
/// This function is thread-safe.
///
void omMMapBinaryFile(
void **constAddr, char *filename, int64_t size, int64_t isLE) {
checkEndianness(isLE);
char *fname = filename;
#ifdef __MVS__
// Convert the file name to EBCDIC for the open call.
char *tPath = strdup(fname);
if (!tPath) {
fprintf(stderr, "Error while strdup");
return;
}
__a2e_s(tPath);
fname = tPath;
#endif

bool omMMapBinaryFile(
void **constAddr, char *fname, int64_t size, int64_t isLE) {
if (constAddr == NULL) {
perror("Error: null pointer");
return;
fprintf(stderr, "Error: null pointer.");
return false;
}

if (constAddr[0] == NULL) {
char *filePath;
char *basePath = getenv("OM_CONSTANT_PATH");
if (basePath) {
size_t baseLen = strlen(basePath);
size_t fnameLen = strlen(fname);
size_t sepLen = strlen(DIR_SEPARATOR);
size_t filePathLen = baseLen + sepLen + fnameLen;
filePath = (char *)malloc(filePathLen);
if (!filePath) {
fprintf(stderr, "Error while malloc");
return;
}
memcpy(filePath, basePath, baseLen);
memcpy(filePath + baseLen, DIR_SEPARATOR, sepLen);
memcpy(filePath + baseLen + sepLen, fname, fnameLen);
filePath[filePathLen] = '\0';
} else {
filePath = (char *)fname;
}
int fd = open(filePath, O_RDONLY);
if (fd < 0) {
fprintf(stderr, "Error while opening %s\n", filePath);
return;
// Already mmaped. Nothing to do.
if (constAddr[0] != NULL)
return true;

char *filePath;
char *basePath = getenv("OM_CONSTANT_PATH");
if (basePath) {
size_t baseLen = strlen(basePath);
size_t fnameLen = strlen(fname);
size_t sepLen = strlen(DIR_SEPARATOR);
size_t filePathLen = baseLen + sepLen + fnameLen + 1;
filePath = (char *)malloc(filePathLen);
if (!filePath) {
fprintf(stderr, "Error while malloc: %s", strerror(errno));
return false;
}
snprintf(filePath, filePathLen, "%s%s%s", basePath, DIR_SEPARATOR, fname);
} else {
filePath = (char *)fname;
}
int fd = open(filePath, O_RDONLY);
if (fd < 0) {
fprintf(stderr, "Error while opening %s: %s\n", filePath, strerror(errno));
if (basePath)
free(filePath);
return false;
}

#ifdef __MVS__
void *tempAddr = mmap(0, size, PROT_READ, __MAP_MEGA, fd, 0);
#else
void *tempAddr = mmap(0, size, PROT_READ, MAP_SHARED, fd, 0);
#endif

if (tempAddr == MAP_FAILED) {
fprintf(stderr, "Error while mmapping %s\n", fname);
close(fd);
return;
}

/* Prepare to compare-and-swap to setup the shared constAddr.
* If we fail, another thread beat us so free our mmap.
*/
#ifdef __MVS__
void *expected = NULL;
if (cds((cds_t *)&expected, (cds_t *)&constAddr[0], *(cds_t *)tempAddr))
munmap(tempAddr, size);
#else
if (!__sync_bool_compare_and_swap(&constAddr[0], NULL, tempAddr))
munmap(tempAddr, size);
#endif

/* Either we succeeded in setting constAddr or someone else did it.
* Either way, constAddr is now setup. We can close our fd without
* invalidating the mmap.
*/
if (tempAddr == MAP_FAILED) {
fprintf(stderr, "Error while mmapping %s: %s\n", fname, strerror(errno));
close(fd);
if (basePath)
free(filePath);
return false;
}

/* Prepare to compare-and-swap to setup the shared constAddr.
* If we fail, another thread beat us so free our mmap.
*/
#ifdef __MVS__
free(tPath);
void *expected = NULL;
if (cds((cds_t *)&expected, (cds_t *)&constAddr[0], *(cds_t *)&tempAddr))
munmap(tempAddr, size);
#else
if (!__sync_bool_compare_and_swap(&constAddr[0], NULL, tempAddr))
munmap(tempAddr, size);
#endif

/* Either we succeeded in setting constAddr or someone else did it.
* Either way, constAddr is now setup. We can close our fd without
* invalidating the mmap.
*/
close(fd);
if (basePath)
free(filePath);
return true;
}

/// Return the address of a constant at a given offset.
Expand All @@ -153,11 +141,11 @@ void omMMapBinaryFile(
void omGetExternalConstantAddr(
void **outputAddr, void **baseAddr, int64_t offset) {
if (outputAddr == NULL) {
perror("Error: null pointer");
fprintf(stderr, "Error: null pointer.");
return;
}
if (baseAddr == NULL) {
perror("Error: null pointer");
fprintf(stderr, "Error: null pointer.");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there no need to track these errors by returning false or true?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be done here, but I am a bit hesitant since this function is called for each constant, so there would be a lot of if-then-else code generated in the main inference to check its return value.

return;
}
// Constant is already loaded. Nothing to do.
Expand Down
Loading
Loading