From b6742695226139d9e970d316272485895a28b6ee Mon Sep 17 00:00:00 2001
From: Yun Dai
Date: Tue, 9 Jul 2024 18:15:46 -0700
Subject: [PATCH] [FRONTEND] let CacheManager write to temp dir instead of temp file (#4295)

# Summary

There have been multiple issues reporting a `FileNotFoundError` during compilation, raised when `CompiledKernel` tries to read the ASM files it has listed: #2688, #4002, https://github.com/vllm-project/vllm/issues/6103, etc. There have also been some attempts to address it, such as #3544. This PR attempts to explain the root cause and suggests a fix.

# Why

When a kernel is compiled, Triton first writes the IRs to the Triton cache dir ([ref](https://github.com/triton-lang/triton/blob/78091647fccb6825ed9956ff7c0300859856d261/python/triton/compiler/compiler.py#L289)). Inside the write operation, the process first writes the data to a temp file unique to the current process (plus a UUID to distinguish between multiple processes with the same PID on different hosts sharing the same underlying FS) ([ref](https://github.com/triton-lang/triton/blob/c14b033cd979d5c39e5fdb3847c022fa5d71a0c1/python/triton/runtime/cache.py#L124-L130)) and then atomically moves it to the final file name with `os.replace`. Afterwards, `CompiledKernel` lists all the IRs and reads them ([ref](https://github.com/triton-lang/triton/blob/78091647fccb6825ed9956ff7c0300859856d261/python/triton/compiler/compiler.py#L362-L367)).

In a multiprocess setup, however, this may result in a race condition. Consider a case with one host and 2 processes on it.

![Triton RC (1)](https://github.com/triton-lang/triton/assets/43726198/ffc20e0c-0404-4e7a-bd6c-022e710e97b9)

At the time `pid 1` lists the ASMs, the dir may contain temp files generated by another process, `pid 2`. By the time `pid 1` proceeds to read bytes from the listed files, however, `pid 2` may already have `os.replace`d its temp files, so `pid 1` encounters a `FileNotFoundError` when it tries to read a temp file generated by `pid 2`.
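
The interleaving described above can be replayed deterministically in a single process. This is only an illustrative sketch, not Triton code: the file names (`kernel.ttir`, `kernel.ptx`), the placeholder UUID, and the function `simulate_race` are hypothetical, and the "two processes" are simulated by running their steps in the problematic order.

```python
import os
import tempfile
from pathlib import Path


def simulate_race():
    """Return the names of listed files that vanished before they could be read."""
    missing = []
    with tempfile.TemporaryDirectory() as cache_dir:
        cache = Path(cache_dir)
        # "pid 1" has already published one of its own artifacts.
        (cache / "kernel.ttir").write_text("ir from pid 1")
        # "pid 2" is mid-write: its temp file sits next to the final files.
        temp = cache / "kernel.ptx.tmp.pid_2_hypothetical-uuid"
        temp.write_text("asm from pid 2")

        # Step 1: "pid 1" lists the directory and picks up the temp file of "pid 2".
        listed = sorted(cache.iterdir())

        # Step 2: "pid 2" finishes and atomically renames its temp file.
        os.replace(temp, cache / "kernel.ptx")

        # Step 3: "pid 1" reads what it listed; the temp path no longer exists.
        for path in listed:
            try:
                path.read_bytes()
            except FileNotFoundError:
                missing.append(path.name)
    return missing


print(simulate_race())  # ['kernel.ptx.tmp.pid_2_hypothetical-uuid']
```

The only file that fails to read is the temp file that was listed before the rename and read after it, which matches the failure reported in the linked issues.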
https://github.com/IBM/vllm/pull/35#issuecomment-2145591548 also points to this race as the root cause.

# How

There are multiple potential solutions, as also mentioned in https://github.com/IBM/vllm/pull/35#issuecomment-2145591548:

- let each process write to a private temp dir instead, so `glob` never picks up the temp files
- or, exclude `tmp.pid_*` from `glob`

This PR goes with the first approach, to avoid adding an assumption about the temp file pattern (which is only used in `runtime/cache.py`) to `compiler/compiler.py`, but is open to any suggestion. Thanks!

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because `not applicable`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
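
The first approach under "How" can be sketched as a standalone function. This is a minimal illustration under stated assumptions, not the actual `CacheManager` code: `put_via_temp_dir` and `demo` are hypothetical names, the `before_publish` hook exists only so the demo can observe the mid-write state, and the reader is assumed to list files with a `kernel.*` glob pattern.

```python
import os
import tempfile
import uuid
from pathlib import Path


def put_via_temp_dir(cache_dir, filename, data, before_publish=None):
    """Write data to cache_dir/filename through a private temp directory."""
    filepath = os.path.join(cache_dir, filename)
    # The temp directory name carries the PID and a random ID, as in the patch.
    temp_dir = os.path.join(cache_dir, f"tmp.pid_{os.getpid()}_{uuid.uuid4()}")
    os.makedirs(temp_dir, exist_ok=True)
    temp_path = os.path.join(temp_dir, filename)
    with open(temp_path, "w") as f:
        f.write(data)
    if before_publish is not None:
        before_publish()  # hypothetical hook: runs while the write is still unpublished
    # Atomic on POSIX if it succeeds, so readers never see a partial file.
    os.replace(temp_path, filepath)
    # Removes the now-empty temp dir; stops at cache_dir because it is not empty.
    os.removedirs(temp_dir)
    return filepath


def demo():
    seen_mid_write = []
    with tempfile.TemporaryDirectory() as cache_dir:
        def list_like_reader():
            # Assumed reader behaviour: a non-recursive glob on the kernel name.
            seen_mid_write.extend(p.name for p in Path(cache_dir).glob("kernel.*"))

        put_via_temp_dir(cache_dir, "kernel.ttir", "ir", before_publish=list_like_reader)
        final = sorted(p.name for p in Path(cache_dir).iterdir())
    return seen_mid_write, final


print(demo())  # ([], ['kernel.ttir'])
```

While the write is in flight, the glob finds nothing because the temp data lives one directory level down. After publishing, only the final file remains in the cache dir.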
---
 python/triton/runtime/cache.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py
index 4b87108e0cd4..d5e777ee9a6c 100644
--- a/python/triton/runtime/cache.py
+++ b/python/triton/runtime/cache.py
@@ -120,14 +120,18 @@ def put(self, data, filename, binary=True) -> str:
         rnd_id = str(uuid.uuid4())
         # we use the PID in case a bunch of these around so we can see what PID made it
         pid = os.getpid()
-        # use tempfile to be robust against program interruptions
-        temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
+        # use temp dir to be robust against program interruptions
+        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
+        os.makedirs(temp_dir, exist_ok=True)
+        temp_path = os.path.join(temp_dir, filename)
+
         mode = "wb" if binary else "w"
         with open(temp_path, mode) as f:
             f.write(data)
         # Replace is guaranteed to be atomic on POSIX systems if it succeeds
         # so filepath cannot see a partial write
         os.replace(temp_path, filepath)
+        os.removedirs(temp_dir)
         return filepath
 
 