diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index 4b87108e0cd4..d5e777ee9a6c 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -120,14 +120,18 @@ def put(self, data, filename, binary=True) -> str: rnd_id = str(uuid.uuid4()) # we use the PID in case a bunch of these around so we can see what PID made it pid = os.getpid() - # use tempfile to be robust against program interruptions - temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}" + # use temp dir to be robust against program interruptions + temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, filename) + mode = "wb" if binary else "w" with open(temp_path, mode) as f: f.write(data) # Replace is guaranteed to be atomic on POSIX systems if it succeeds # so filepath cannot see a partial write os.replace(temp_path, filepath) + os.removedirs(temp_dir) return filepath