[FRONTEND] let CacheManager write to temp dir instead of temp file (triton-lang#4295)

# Summary
There have been multiple issues discussing the `FileNotFoundError`
raised during compilation when `CompiledKernel` tries to read the listed
ASM files (triton-lang#2688, triton-lang#4002, vllm-project/vllm#6103,
etc.), and there have been some attempts to address it, such as triton-lang#3544.
This PR attempts to explain the root cause and suggest a fix.

# Why
When a kernel is being compiled, Triton first writes the IRs to the Triton
cache dir
([ref](https://github.com/triton-lang/triton/blob/78091647fccb6825ed9956ff7c0300859856d261/python/triton/compiler/compiler.py#L289)).
Inside the write operation, the process first writes the data to a temp
file unique to the current process (the name also carries a uuid to
distinguish between processes with the same PID on different hosts sharing
the same underlying FS)
([ref](https://github.com/triton-lang/triton/blob/c14b033cd979d5c39e5fdb3847c022fa5d71a0c1/python/triton/runtime/cache.py#L124-L130))
and then atomically moves it to the final file name with `os.replace`.
Afterwards, the `CompiledKernel` lists all the IRs and reads them
([ref](https://github.com/triton-lang/triton/blob/78091647fccb6825ed9956ff7c0300859856d261/python/triton/compiler/compiler.py#L362-L367)).
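
The write path described above can be sketched roughly as follows. This is a minimal stand-alone illustration, not Triton's actual code: `put_via_temp_file` and `demo_put` are hypothetical names, and only the temp-file naming scheme is taken from the diff in this commit.

```python
import os
import tempfile
import uuid


def put_via_temp_file(filepath: str, data: bytes) -> str:
    """Hypothetical helper: write to a sibling temp file, then rename it."""
    # Temp name follows the pre-change pattern: <final path>.tmp.pid_<pid>_<uuid>
    temp_path = f"{filepath}.tmp.pid_{os.getpid()}_{uuid.uuid4()}"
    with open(temp_path, "wb") as f:
        f.write(data)
    # Readers of `filepath` see either the old or the new content, never a
    # partially written file.
    os.replace(temp_path, filepath)
    return filepath


def demo_put() -> tuple:
    # The file name "kernel.ptx" is an arbitrary example.
    with tempfile.TemporaryDirectory() as cache_dir:
        path = os.path.join(cache_dir, "kernel.ptx")
        put_via_temp_file(path, b"v1")
        put_via_temp_file(path, b"v2")
        with open(path, "rb") as f:
            content = f.read()
        return content, sorted(os.listdir(cache_dir))
```

Note that while a write is in flight, the temp file sits in the same directory as the final files, which is what makes it visible to a directory listing.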

In a multiprocess setup, however, this may result in a race condition.
Let's focus on the case of a single host with 2 processes on it.
![Triton RC
(1)](https://github.com/triton-lang/triton/assets/43726198/ffc20e0c-0404-4e7a-bd6c-022e710e97b9)

At the time when `pid 1` lists the ASMs, the dir may contain temp files
generated by another process, `pid 2`. However, by the time `pid 1`
proceeds to read bytes from the listed files, `pid 2` may have already
moved its temp files with `os.replace`, so `pid 1` will encounter a
`FileNotFoundError` when trying to read a temp file generated by
`pid 2`. IBM/vllm#35 (comment) also identifies this as the root cause.
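
The interleaving can be replayed deterministically in a single process. The sketch below is only an illustration under stated assumptions: the two "processes" are simulated by ordering the steps by hand, `glob.glob` stands in for however the compiler lists the files, and `simulate_race` and the file name `kernel.ttir` are hypothetical.

```python
import glob
import os
import tempfile
import uuid


def simulate_race() -> str:
    """Replay list -> replace -> read, the order that triggers the error."""
    with tempfile.TemporaryDirectory() as cache_dir:
        final_path = os.path.join(cache_dir, "kernel.ttir")

        # "pid 2": starts a write using a temp file next to the final file.
        temp_path = f"{final_path}.tmp.pid_2_{uuid.uuid4()}"
        with open(temp_path, "w") as f:
            f.write("ir")

        # "pid 1": lists the directory and picks up the in-flight temp file.
        listed = glob.glob(os.path.join(cache_dir, "*"))

        # "pid 2": finishes the write; the temp name no longer exists.
        os.replace(temp_path, final_path)

        # "pid 1": reads what it listed earlier.
        try:
            for path in listed:
                with open(path) as f:
                    f.read()
        except FileNotFoundError:
            return "FileNotFoundError"
        return "ok"
```

In a real run the outcome depends on timing, which is why the failure shows up only intermittently.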

# How
There are multiple potential solutions to this, as also mentioned in
IBM/vllm#35 (comment):
- let each process write to a private temp dir instead, so that `glob`
never picks up the temp files
- or, exclude `tmp.pid_*` from `glob`

This PR goes with the first approach, which avoids adding an assumption
about the temp file pattern (used only in `runtime/cache.py`) to
`compiler/compiler.py`, but is open to any suggestion. Thanks!
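
A minimal stand-alone sketch of the first approach, assuming a plain function in place of the `CacheManager.put` method: `put_via_temp_dir` and `demo_temp_dir` are hypothetical names, while the sequence of calls mirrors the diff in this commit.

```python
import os
import tempfile
import uuid


def put_via_temp_dir(cache_dir: str, filename: str, data: str) -> str:
    """Hypothetical helper: stage the file in a private temp dir first."""
    filepath = os.path.join(cache_dir, filename)
    temp_dir = os.path.join(cache_dir, f"tmp.pid_{os.getpid()}_{uuid.uuid4()}")
    os.makedirs(temp_dir, exist_ok=True)
    temp_path = os.path.join(temp_dir, filename)
    with open(temp_path, "w") as f:
        f.write(data)
    # The temp dir lives inside cache_dir, so the rename stays on one
    # filesystem.
    os.replace(temp_path, filepath)
    # Removes the now-empty temp dir; cache_dir itself is not empty, so the
    # upward pruning done by os.removedirs stops there.
    os.removedirs(temp_dir)
    return filepath


def demo_temp_dir() -> list:
    with tempfile.TemporaryDirectory() as cache_dir:
        put_via_temp_dir(cache_dir, "kernel.ttir", "ir")
        return sorted(os.listdir(cache_dir))
```

With this layout an in-flight file never appears among the final files in the cache dir; only the private `tmp.pid_*` directory is briefly visible while a write is in progress.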

Complete the following tasks before sending your PR, and replace `[ ]`
with `[x]` to indicate you have done them.

- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because `not applicable`.

- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best
    practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
    including the "tests should be minimal" section. (Usually running Python
    code and using the instructions it generates is not minimal.)
yundai424 authored and bertmaher committed Dec 4, 2024
1 parent 43d0cb8 commit 098fbf0
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/triton/runtime/cache.py
@@ -120,14 +120,18 @@ def put(self, data, filename, binary=True) -> str:
         rnd_id = str(uuid.uuid4())
         # we use the PID in case a bunch of these around so we can see what PID made it
         pid = os.getpid()
-        # use tempfile to be robust against program interruptions
-        temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
+        # use temp dir to be robust against program interruptions
+        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
+        os.makedirs(temp_dir, exist_ok=True)
+        temp_path = os.path.join(temp_dir, filename)
+
         mode = "wb" if binary else "w"
         with open(temp_path, mode) as f:
             f.write(data)
         # Replace is guaranteed to be atomic on POSIX systems if it succeeds
         # so filepath cannot see a partial write
         os.replace(temp_path, filepath)
+        os.removedirs(temp_dir)
         return filepath

