From bc632308aca5901a81f4047e329a66f1ccd98870 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Wed, 10 Jul 2024 13:47:49 -0700
Subject: [PATCH] Parallelize SPIR-V compilation (#4200)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4200

Due to a high number of shaders to compile, shader compilation has become quite slow (taking around 1 minute from empirical observations). This change updates the SPIR-V compilation script to parallelize SPIR-V compilation over the available number of CPU cores.

Testing the impact of parallelization using the following command:

```
buck build //xplat/executorch/backends/vulkan:gen_vulkan_graph_runtime_shaderlib_cpp --out ~/scratch/shaders
```

We can observe the following improvement:

| | Before | After|
|--- | --- | --- |
| M1 Mac | 42.4s | 7.8s |

Reviewed By: liuk22, jorgep31415

Differential Revision: D59597186

fbshipit-source-id: 72df756a3a1f818688af66bbce4fc9fc02f1f505
---
 backends/vulkan/runtime/gen_vulkan_spv.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 5661aed4c8..f5cfba3142 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -15,6 +15,7 @@ import re import sys from itertools import product +from multiprocessing.pool import ThreadPool sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) import subprocess @@ -620,9 +621,12 @@ def constructOutputMap(self) -> None: def generateSPV(self, output_dir: str) -> Dict[str, str]: output_file_map = {} - for shader_name in self.output_shader_map: - source_glsl = self.output_shader_map[shader_name][0] - shader_params = self.output_shader_map[shader_name][1] + + def process_shader(shader_paths_pair): + shader_name = shader_paths_pair[0] + + source_glsl = shader_paths_pair[1][0] + shader_params = shader_paths_pair[1][1] with codecs.open(source_glsl, "r", 
encoding="utf-8") as input_file: input_text = input_file.read() @@ -652,9 +656,15 @@ def generateSPV(self, output_dir: str) -> Dict[str, str]: ] print("glslc cmd:", cmd) - # pyre-ignore subprocess.check_call(cmd) + return (spv_out_path, glsl_out_path) + + # Parallelize shader compilation as much as possible to optimize build time. + with ThreadPool(os.cpu_count()) as pool: + for spv_out_path, glsl_out_path in pool.map( + process_shader, self.output_shader_map.items() + ): output_file_map[spv_out_path] = glsl_out_path return output_file_map