Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test the return value of omMMapBinaryFile function and terminate the main program elegantly #3002

Merged
merged 13 commits into from
Nov 15, 2024

Conversation

tungld
Copy link
Collaborator

@tungld tungld commented Nov 8, 2024

Currently the main program still runs even if omMMapBinaryFile fails to mmap the constant file, so it's difficult for debugging.

This patch changes omMMapBinaryFile to return a boolean value indicating success and failure, so that the main program can exit elegantly.

This patch also includes a fix for z/OS we just found.

create.llvm._return(create.llvm.null(getI8PointerType(context)));
});
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this code into KrnlToLLVMHelper.{hpp/cpp} so that it can be reused.

@tungld
Copy link
Collaborator Author

tungld commented Nov 8, 2024

@christopherlmunoz FYI. This is not to fix issues about mmapping a constant file on z/OS, but to make debugging easier.

@tungld tungld requested a review from gongsu832 November 8, 2024 22:13
Signed-off-by: Tung D. Le <[email protected]>
}

/* Prepare to compare-and-swap to setup the shared constAddr.
* If we fail, another thread beat us so free our mmap.
*/
#ifdef __MVS__
void *expected = NULL;
if (cds((cds_t *)&expected, (cds_t *)&constAddr[0], *(cds_t *)tempAddr))
if (cds((cds_t *)&expected, (cds_t *)&constAddr[0], *(cds_t *)&tempAddr))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Include a fix for z/OS.

@@ -67,15 +68,18 @@ void omMMapBinaryFile(
char *tPath = strdup(fname);
if (!tPath) {
fprintf(stderr, "Error while strdup");
return;
return false;
}
__a2e_s(tPath);
fname = tPath;
#endif

if (constAddr == NULL) {
perror("Error: null pointer");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of mixing fprintf and perror, either use fprintf everywhere but add strerror(errno) for more informative printout, or use perror everywhere.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, now using fprintf everywhere.

@@ -89,7 +93,10 @@ void omMMapBinaryFile(
filePath = (char *)malloc(filePathLen);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably want to +1 to filePathLen to account for the \0 you add to the end later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I updated this.

#ifdef __MVS__
free(tPath);
#endif
return false;
}
memcpy(filePath, basePath, baseLen);
Copy link
Collaborator

@gongsu832 gongsu832 Nov 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The memcpy and adding \0 can be replaced by a single snprintf:

sprintf(filePath, filePathLen, "%s%s%s", basePath, DIR_SEPARATOR, fname);

(assuming filePathLen was incremented by 1 above).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I updated by using snprintf

@@ -43,7 +43,7 @@ const int i = 1;
void checkEndianness(const char constPackIsLE) {
if (XOR(IS_SYSTEM_LE(), constPackIsLE)) {
fprintf(stderr, "Constant pack is stored in a byte order that is not "
"native to this current system.");
"native to this current system: %s", strerror(errno));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I probably didn't make it clear. errno is only meaningful after you make a library call such as strdup, open, malloc, etc. Here strerror(errno) is meaningless.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thought. Will remove them.

return false;
}
__a2e_s(tPath);
fname = tPath;
#endif

if (constAddr == NULL) {
perror("Error: null pointer");
fprintf(stderr, "Error: null pointer: %s", strerror(errno));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This strerror(errno) is also meaningless.

@@ -158,7 +155,9 @@ bool omMMapBinaryFile(
/* Make sure constAddr is the same as the mmap address.
*/
if (constAddr[0] != tempAddr) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is wrong. Only the first thread that successfully performs the compare-and-swap will have constAddr[0] == tempAddr. All other threads will have constAddr[0] != tempAddr. But it doesn't matter since by design they will throw away their tempAddr and use the constAddr[0] set by the first thread.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does tempAddr become NULL after munmap? If so, we do the test only when tempAddr != NULL.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I explicitly set tempAddr = NULL after munmap, and check constAddr[0] != tempAddr only when tempAddr != NULL (meaning for the first successful thread).

Copy link
Collaborator

@gongsu832 gongsu832 Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does tempAddr become NULL after munmap? If so, we do the test only when tempAddr != NULL.

Not sure what you mean. munmap cannot change tempAddr since C is passing by value (i.e., inside munmap it only has access to a copy of tempAddr). But even if tempAddr somehow becomes NULL after munmap, the check constAddr[0] != tempAddr is still wrong.

Copy link
Collaborator

@gongsu832 gongsu832 Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I explicitly set tempAddr = NULL after munmap, and check constAddr[0] != tempAddr only when tempAddr != NULL (meaning for the first successful thread).

Not sure why you want to do that. A successful compare-and-swap guarantees that constAddr[0] == tempAddr so the check would be redundant.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, you are right. In the latest code, I set tempAddr = NULL for failed threads, and check if (tempAddr && (constAddr[0] != tempAddr)). What do you think?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to avoid the situation we encountered where cds was written wrongly. Perhaps, assert is better.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I removed the check. I was overthinking about having a check to debug, but since we found the issue, and cds would work correctly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you can use assert but it wouldn't help in this case. Because even if you write your cds wrong like the bug we just had, in the first thread after the cds you will have constAddr[0] == tempAddr. Because cds doesn't know/care if the tempAddr you supplied in the call is garbage. All it does is setting constAddr[0] to tempAddr atomically, regardless of what's in tempAddr.

@@ -181,11 +180,11 @@ bool omMMapBinaryFile(
void omGetExternalConstantAddr(
void **outputAddr, void **baseAddr, int64_t offset) {
if (outputAddr == NULL) {
perror("Error: null pointer");
fprintf(stderr, "Error: null pointer: %s", strerror(errno));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strerror(errno) is meaningless here.

return;
}
if (baseAddr == NULL) {
perror("Error: null pointer");
fprintf(stderr, "Error: null pointer: %s", strerror(errno));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strerror(errno) is meaningless here.

@tungld
Copy link
Collaborator Author

tungld commented Nov 13, 2024

@gongsu832 @AlexandreEichenberger could you take another look at this PR? Apart from handling return value of omMMapBinaryFile, it contains a bug fix for the issue #2993. Thanks!

Copy link
Collaborator

@AlexandreEichenberger AlexandreEichenberger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, please have @gongsu832 provide the feedback specific to the mmap part.

MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
// Print an error message.
if (!errorMsg.empty())
create.krnl.printf(StringRef(errorMsg + "\n"));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we start to use the krnl.printf for more than debugging, I have been often experiencing issues with that call (in debugging situations) where it print too long a string. We have some issues with that call.

I had similar issue in the krnl.PrintTensor and I have alleviated this by having a special %e sequence at the end of the string. For Instrumentation, I think I encoded the string length in an integer flag. All tricks to get around this issue. If someone could try to understand what is happening, it would be great!

return;
}
if (baseAddr == NULL) {
perror("Error: null pointer");
fprintf(stderr, "Error: null pointer.");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to track these error by returning false or true?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can be here, but I am a bit hesitated since this is called per each constant so there would have a lot of if then else generated in the main inference for checking the return value of this function.

@@ -342,5 +343,47 @@ bool isZOS(ModuleOp module) {
return zOS;
}

void equalOrFailed(ModuleOp &module, OpBuilder &rewriter, Location loc,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would these calls also be potentially useful if we wanted to fail the inference, say after a NNPA severe failure call?
If so, we may want to see if we would benefit from introducing something at a higher level dialects (say Krnl or SCF) that would then feed into this one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally yes. It will be used anywhere in the main inference when we want to fail the inference. We can introduce a krnl op for this.

fprintf(stderr, "Error: null pointer.");
#ifdef __MVS__
free(tPath);
#endif
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a big deal but a bit of an eye sore to see the #ifdef __MVS__ everywhere. Maybe it's simpler to just define two macros:

#ifdef __MVS__
#define CONV_PATH(p) \
  char *p = strdup(fname); \
  if (!p) { \
    fprintf(stderr, "Error while strdup: %s", strerror(errno)); \
    return false; \
  } \
  __a2e_s(p); \
  fname = p
#define FREE_PATH(p) free(p)
#else
#define CONV_PATH(p)
#define FREE_PATH(p)
#endif

and then just use CONV_PATH(tPath) and FREE_PATH(tPath) everywhere.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gongsu832 all of these issues come from the fact that fname is not in EBCDIC. So I modify the code generation to convert fname to EBCDIC during compilation so that fname here would be alread in EBCDIC, then there is no need to handle it here anymore.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK that makes things simpler.

Copy link
Collaborator

@gongsu832 gongsu832 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@tungld tungld merged commit 868432d into onnx:main Nov 15, 2024
7 checks passed
@jenkins-droid
Copy link
Collaborator

Jenkins Linux amd64 Build #15991 [push] Add the return value to ... started at 20:25

@jenkins-droid
Copy link
Collaborator

Jenkins Linux s390x Build #15994 [push] Add the return value to ... started at 21:25

@jenkins-droid
Copy link
Collaborator

Jenkins Linux ppc64le Build #15021 [push] Add the return value to ... started at 21:40

@jenkins-droid
Copy link
Collaborator

Jenkins Linux amd64 Build #15991 [push] Add the return value to ... passed after 1 hr 22 min

@jenkins-droid
Copy link
Collaborator

Jenkins Linux s390x Build #15994 [push] Add the return value to ... passed after 1 hr 29 min

@jenkins-droid
Copy link
Collaborator

Jenkins Linux ppc64le Build #15021 [push] Add the return value to ... passed after 2 hr 20 min

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants