Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test the return value of omMMapBinaryFile function and terminate the main program elegantly #3002

Merged
merged 13 commits into from
Nov 15, 2024
Merged
20 changes: 13 additions & 7 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,15 +612,15 @@ void loadConstantsFromFile(ModuleOp &module,
OpBuilder b(ctx);
MultiDialectBuilder<LLVMBuilder> create(b, loc);

Type llvmI1Ty = IntegerType::get(ctx, 1);
Type llvmI8Ty = IntegerType::get(ctx, 8);
Type llvmI64Ty = IntegerType::get(ctx, 64);
Type llvmI8PtrTy = getPointerType(ctx, llvmI8Ty);
Type llvmVoidTy = LLVM::LLVMVoidType::get(ctx);

// The following function will be emitted inside the IR to load constants from
// file.
std::string loadAllConstantsFuncName = "omLoadConstantsFromFile";
Type llvmFnType = LLVM::LLVMFunctionType::get(llvmVoidTy, {}, false);
Type llvmFnType = LLVM::LLVMFunctionType::get(llvmI1Ty, {}, false);

// If calledByEntryPoint, this function will be called by entry points.
// Otherwise, user program (C/C++/Java/Python) would call this function.
Expand All @@ -629,6 +629,7 @@ void loadConstantsFromFile(ModuleOp &module,
Operation *firstEntryPointOp =
getFirstEntryOpInBlock(module, entryGlobalOps);
assert(firstEntryPointOp && "No entry function exists");
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(firstEntryPointOp);
funcOp = create.llvm.func(
loadAllConstantsFuncName, llvmFnType, /*createUniqueFunc=*/true);
Expand All @@ -646,13 +647,16 @@ void loadConstantsFromFile(ModuleOp &module,
std::find(entryName.begin(), entryName.end(), '\0'), entryName.end());
auto entryFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(entryName);
assert(entryFunc && "Entry function not found");
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(
&entryFunc.getBody().front(), entryFunc.getBody().front().begin());
FlatSymbolRefAttr loadAllConstantsRef = create.llvm.getOrInsertSymbolRef(
module, LLVMBuilder::SymbolPostfix(module, loadAllConstantsFuncName),
llvmVoidTy, {},
llvmI1Ty, {},
/*isVarArg=*/false);
create.llvm.call({}, loadAllConstantsRef, {});
Value retVal = create.llvm.call({llvmI1Ty}, loadAllConstantsRef, {});
equalOrFailed(module, b, loc,
create.llvm.constant(llvmI1Ty, static_cast<int64_t>(1)), retVal);
}
} else {
OpBuilder::InsertionGuard guard(b);
Expand Down Expand Up @@ -697,8 +701,11 @@ void loadConstantsFromFile(ModuleOp &module,
// Call a function to mmap the binary file to memory.
Value isleVal = create.llvm.constant(llvmI64Ty, isle);
Value sizeVal = create.llvm.constant(llvmI64Ty, dataSize);
RuntimeAPI::callApi(b, loc, apiRegistry, RuntimeAPI::API::MMAP_BINARY_FILE,
Value retVal = RuntimeAPI::callApi(b, loc, apiRegistry,
RuntimeAPI::API::MMAP_BINARY_FILE,
{packedGlobalPtr, fnameI8Ptr, sizeVal, isleVal});
equalOrReturn(module, b, loc,
create.llvm.constant(llvmI1Ty, static_cast<int64_t>(1)), retVal, retVal);

// Now set pointers for constants in the IR
module->walk([&](LLVM::GlobalOp dataGlobalOp) -> WalkResult {
Expand All @@ -725,11 +732,10 @@ void loadConstantsFromFile(ModuleOp &module,
RuntimeAPI::callApi(b, loc, apiRegistry,
RuntimeAPI::API::GET_EXTERNAL_CONSTANT_ADDR,
{dataPtr, packedGlobalPtr, offsetVal});

return WalkResult::advance();
});

create.llvm._return();
create.llvm._return(create.llvm.constant(llvmI1Ty, static_cast<int64_t>(1)));
}

//===----------------------------------------------------------------------===//
Expand Down
25 changes: 0 additions & 25 deletions src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,31 +412,6 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
rewriter.getI64Type(), {rewriter.getI64Type()});
}

// Emit code for `IF lhs != rhs THEN return null ELSE do nothing`
void equalOrFailed(ModuleOp &module, PatternRewriter &rewriter, Location loc,
Value lhs, Value rhs, std::string errorMsg = "",
bool appendRHS = true) const {
MLIRContext *context = rewriter.getContext();
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
create.llvm.ifThenElse(/*cond=*/
[&](const LLVMBuilder &createLLVM) {
return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs);
}, /*then=*/
[&](const LLVMBuilder &createLLVM) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
// Print an error message.
if (appendRHS)
create.krnl.printf(
StringRef(errorMsg), rhs, rewriter.getI64Type(), true);
else
create.krnl.printf(StringRef(errorMsg + "\n"));
// Set errno.
krnl::emitErrNo(module, rewriter, loc, EINVAL);
// Return NULL.
create.llvm._return(create.llvm.null(getI8PointerType(context)));
});
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this code into KrnlToLLVMHelper.{hpp/cpp} so that it can be reused.

void emitVerificationCodeForInputTensors(ModuleOp &module,
PatternRewriter &rewriter, Location loc,
const RuntimeAPIRegistry &apiRegistry, Value omTensorInputs,
Expand Down
44 changes: 44 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "onnx-mlir/Compiler/OMCompilerRuntimeTypes.h"
#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
Expand Down Expand Up @@ -342,5 +343,48 @@ bool isZOS(ModuleOp module) {
return zOS;
}

void equalOrFailed(ModuleOp &module, OpBuilder &rewriter, Location loc,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would these calls also be potentially useful if we wanted to fail the inference, say after a NNPA severe failure call?
If so, we may want to see if we would benefit from introducing something at a higher level dialects (say Krnl or SCF) that would then feed into this one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally yes. It will be used anywhere in the main inference when we want to fail the inference. We can introduce a krnl op for this.

Value lhs, Value rhs, std::string errorMsg, bool appendRHS) {
MLIRContext *context = rewriter.getContext();
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(rewriter, loc);
create.llvm.ifThenElse(/*cond=*/
[&](const LLVMBuilder &createLLVM) {
return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs);
}, /*then=*/
[&](const LLVMBuilder &createLLVM) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
// Print an error message.
if (!errorMsg.empty()) {
if (appendRHS)
create.krnl.printf(
StringRef(errorMsg), rhs, rewriter.getI64Type(), true);
else
create.krnl.printf(StringRef(errorMsg + "\n"));
}
// Set errno.
emitErrNo(module, rewriter, loc, EINVAL);
// Return NULL.
create.llvm._return(create.llvm.null(getI8PointerType(context)));
});
}

void equalOrReturn(ModuleOp &module, OpBuilder &rewriter, Location loc,
Value lhs, Value rhs, Value retVal, std::string errorMsg) {
MLIRContext *context = rewriter.getContext();
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(rewriter, loc);
create.llvm.ifThenElse(/*cond=*/
[&](const LLVMBuilder &createLLVM) {
return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs);
}, /*then=*/
[&](const LLVMBuilder &createLLVM) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
// Print an error message.
if (!errorMsg.empty())
create.krnl.printf(StringRef(errorMsg + "\n"));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we start to use the krnl.printf for more than debugging, I have been often experiencing issues with that call (in debugging situations) where it print too long a string. We have some issues with that call.

I had similar issue in the krnl.PrintTensor and I have alleviated this by having a special %e sequence at the end of the string. For Instrumentation, I think I encoded the string length in an integer flag. All tricks to get around this issue. If someone could try to understand what is happening, it would be great!

// Return retVal.
create.llvm._return(retVal);
});
}

} // namespace krnl
} // namespace onnx_mlir
10 changes: 10 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ std::string e2a_s(std::string e_s);
void emitErrNo(mlir::ModuleOp module, mlir::OpBuilder &builder,
mlir::Location loc, int err);

/// Emit code for `IF lhs != rhs THEN return null ELSE do nothing`.
void equalOrFailed(mlir::ModuleOp &module, mlir::OpBuilder &rewriter,
mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
std::string errorMsg = "", bool appendRHS = true);

/// Emit code for `IF lhs != rhs THEN return retVal ELSE do nothing`.
void equalOrReturn(mlir::ModuleOp &module, mlir::OpBuilder &rewriter,
mlir::Location loc, mlir::Value lhs, mlir::Value rhs, mlir::Value retVal,
std::string errorMsg = "");

/// Creates an LLVM pointer type with the given element type and address space.
/// This function is meant to be used in code supporting both typed and opaque
/// pointers, as it will create an opaque pointer with the given address space
Expand Down
3 changes: 2 additions & 1 deletion src/Conversion/KrnlToLLVM/RuntimeAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ RuntimeAPIRegistry::RuntimeAPIRegistry(
: registry() {
MLIRContext *context = module.getContext();
auto voidTy = LLVM::LLVMVoidType::get(context);
Type int1Ty = IntegerType::get(context, 1);
auto int8Ty = IntegerType::get(context, 8);
auto opaquePtrTy = onnx_mlir::krnl::getPointerType(context, int8Ty);
auto opaquePtrPtrTy = onnx_mlir::krnl::getPointerType(context, opaquePtrTy);
Expand All @@ -88,7 +89,7 @@ RuntimeAPIRegistry::RuntimeAPIRegistry(
RuntimeAPI(API::GET_OMT_ARRAY, "omTensorListGetOmtArray", opaquePtrPtrTy, {opaquePtrTy}),
RuntimeAPI(API::PRINT_OMTENSOR, "omTensorPrint", voidTy, {opaquePtrTy, opaquePtrTy}),
RuntimeAPI(API::GET_OMTENSOR_LIST_SIZE, "omTensorListGetSize", int64Ty, {opaquePtrTy}),
RuntimeAPI(API::MMAP_BINARY_FILE, "omMMapBinaryFile", voidTy, {opaquePtrPtrTy, opaquePtrTy, int64Ty, int64Ty}),
RuntimeAPI(API::MMAP_BINARY_FILE, "omMMapBinaryFile", int1Ty, {opaquePtrPtrTy, opaquePtrTy, int64Ty, int64Ty}),
RuntimeAPI(API::GET_EXTERNAL_CONSTANT_ADDR, "omGetExternalConstantAddr", voidTy, {opaquePtrPtrTy, opaquePtrPtrTy, int64Ty}),
};
// clang-format on
Expand Down
67 changes: 47 additions & 20 deletions src/Runtime/OMExternalConstant.inc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ typedef int make_iso_compilers_happy;

#include <errno.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
Expand All @@ -42,7 +43,7 @@ const int i = 1;
void checkEndianness(const char constPackIsLE) {
if (XOR(IS_SYSTEM_LE(), constPackIsLE)) {
fprintf(stderr, "Constant pack is stored in a byte order that is not "
"native to this current system.");
"native to this current system: %s", strerror(errno));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I probably didn't make it clear. errno is only meaningful after you make a library call such as strdup, open, malloc, etc. Here strerror(errno) is meaningless.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thought. Will remove them.

exit(1);
}
}
Expand All @@ -58,24 +59,27 @@ void checkEndianness(const char constPackIsLE) {
///
/// This function is thread-safe.
///
void omMMapBinaryFile(
bool omMMapBinaryFile(
void **constAddr, char *filename, int64_t size, int64_t isLE) {
checkEndianness(isLE);
char *fname = filename;
#ifdef __MVS__
// Convert the file name to EBCDIC for the open call.
char *tPath = strdup(fname);
if (!tPath) {
fprintf(stderr, "Error while strdup");
return;
fprintf(stderr, "Error while strdup: %s", strerror(errno));
return false;
}
__a2e_s(tPath);
fname = tPath;
#endif

if (constAddr == NULL) {
perror("Error: null pointer");
return;
fprintf(stderr, "Error: null pointer: %s", strerror(errno));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This strerror(errno) is also meaningless.

#ifdef __MVS__
free(tPath);
#endif
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a big deal but a bit of an eye sore to see the #ifdef __MVS__ everywhere. Maybe it's simpler to just define two macros:

#ifdef __MVS__
#define CONV_PATH(p) \
  char *p = strdup(fname); \
  if (!p) { \
    fprintf(stderr, "Error while strdup: %s", strerror(errno)); \
    return false; \
  } \
  __a2e_s(p); \
  fname = p
#define FREE_PATH(p) free(p)
#else
#define CONV_PATH(p)
#define FREE_PATH(p)
#endif

and then just use CONV_PATH(tPath) and FREE_PATH(tPath) everywhere.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gongsu832 all of these issues come from the fact that fname is not in EBCDIC. So I modify the code generation to convert fname to EBCDIC during compilation so that fname here would be alread in EBCDIC, then there is no need to handle it here anymore.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK that makes things simpler.

return false;
}

if (constAddr[0] == NULL) {
Expand All @@ -85,23 +89,28 @@ void omMMapBinaryFile(
size_t baseLen = strlen(basePath);
size_t fnameLen = strlen(fname);
size_t sepLen = strlen(DIR_SEPARATOR);
size_t filePathLen = baseLen + sepLen + fnameLen;
size_t filePathLen = baseLen + sepLen + fnameLen + 1;
filePath = (char *)malloc(filePathLen);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably want to +1 to filePathLen to account for the \0 you add to the end later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I updated this.

if (!filePath) {
fprintf(stderr, "Error while malloc");
return;
fprintf(stderr, "Error while malloc: %s", strerror(errno));
#ifdef __MVS__
free(tPath);
#endif
return false;
}
memcpy(filePath, basePath, baseLen);
memcpy(filePath + baseLen, DIR_SEPARATOR, sepLen);
memcpy(filePath + baseLen + sepLen, fname, fnameLen);
filePath[filePathLen] = '\0';
snprintf(filePath, filePathLen, "%s%s%s", basePath, DIR_SEPARATOR, fname);
} else {
filePath = (char *)fname;
}
int fd = open(filePath, O_RDONLY);
if (fd < 0) {
fprintf(stderr, "Error while opening %s\n", filePath);
return;
fprintf(stderr, "Error while opening %s: %s\n", filePath, strerror(errno));
if (basePath)
free(filePath);
#ifdef __MVS__
free(tPath);
#endif
return false;
}

#ifdef __MVS__
Expand All @@ -111,17 +120,22 @@ void omMMapBinaryFile(
#endif

if (tempAddr == MAP_FAILED) {
fprintf(stderr, "Error while mmapping %s\n", fname);
fprintf(stderr, "Error while mmapping %s: %s\n", fname, strerror(errno));
close(fd);
return;
if (basePath)
free(filePath);
#ifdef __MVS__
free(tPath);
#endif
return false;
}

/* Prepare to compare-and-swap to setup the shared constAddr.
* If we fail, another thread beat us so free our mmap.
*/
#ifdef __MVS__
void *expected = NULL;
if (cds((cds_t *)&expected, (cds_t *)&constAddr[0], *(cds_t *)tempAddr))
if (cds((cds_t *)&expected, (cds_t *)&constAddr[0], *(cds_t *)&tempAddr))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Include a fix for z/OS.

munmap(tempAddr, size);
#else
if (!__sync_bool_compare_and_swap(&constAddr[0], NULL, tempAddr))
Expand All @@ -135,11 +149,24 @@ void omMMapBinaryFile(
close(fd);
if (basePath)
free(filePath);
#ifdef __MVS__
free(tPath);
#endif
/* Make sure constAddr is the same as the mmap address.
*/
if (constAddr[0] != tempAddr) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is wrong. Only the first thread that successfully performs the compare-and-swap will have constAddr[0] == tempAddr. All other threads will have constAddr[0] != tempAddr. But it doesn't matter since by design they will throw away their tempAddr and use the constAddr[0] set by the first thread.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does tempAddr become NULL after munmap? If so, we do the test only when tempAddr != NULL.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I explicitly set tempAddr = NULL after munmap, and check constAddr[0] != tempAddr only when tempAddr != NULL (meaning for the first successful thread).

Copy link
Collaborator

@gongsu832 gongsu832 Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does tempAddr become NULL after munmap? If so, we do the test only when tempAddr != NULL.

Not sure what you mean. munmap cannot change tempAddr since C is passing by value (i.e., inside munmap it only has access to a copy of tempAddr). But even if tempAddr somehow becomes NULL after munmap, the check constAddr[0] != tempAddr is still wrong.

Copy link
Collaborator

@gongsu832 gongsu832 Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I explicitly set tempAddr = NULL after munmap, and check constAddr[0] != tempAddr only when tempAddr != NULL (meaning for the first successful thread).

Not sure why you want to do that. A successful compare-and-swap guarantees that constAddr[0] == tempAddr so the check would be redundant.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, you are right. In the latest code, I set tempAddr = NULL for failed threads, and check if (tempAddr && (constAddr[0] != tempAddr)). What do you think?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to avoid the situation we encountered where cds was written wrongly. Perhaps, assert is better.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I removed the check. I was overthinking about having a check to debug, but since we found the issue, and cds would work correctly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you can use assert but it wouldn't help in this case. Because even if you write your cds wrong like the bug we just had, in the first thread after the cds you will have constAddr[0] == tempAddr. Because cds doesn't know/care if the tempAddr you supplied in the call is garbage. All it does is setting constAddr[0] to tempAddr atomically, regardless of what's in tempAddr.

fprintf(stderr,
"Error while updating the buffer address for constants: %s\n",
strerror(errno));
return false;
}
return true;
}

#ifdef __MVS__
free(tPath);
#endif
return true;
}

/// Return the address of a constant at a given offset.
Expand All @@ -153,11 +180,11 @@ void omMMapBinaryFile(
void omGetExternalConstantAddr(
void **outputAddr, void **baseAddr, int64_t offset) {
if (outputAddr == NULL) {
perror("Error: null pointer");
fprintf(stderr, "Error: null pointer: %s", strerror(errno));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strerror(errno) is meaningless here.

return;
}
if (baseAddr == NULL) {
perror("Error: null pointer");
fprintf(stderr, "Error: null pointer: %s", strerror(errno));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strerror(errno) is meaningless here.

return;
}
// Constant is already loaded. Nothing to do.
Expand Down
Loading
Loading