Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vulkan: multithread pipeline creation #963

Merged
merged 1 commit into from
Sep 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions src/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <unordered_map>
#include <memory>
#include <mutex>
#include <future>
#include <thread>
#include <system_error>

#include "ggml.h"
#include "ggml-backend-impl.h"
Expand Down Expand Up @@ -607,13 +609,16 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx

GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);

static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
// Throttling state shared by every asynchronous pipeline compile.
//   compile_count_mutex : guards compile_count.
//   compile_count_cond  : notified whenever a compile finishes, so a launcher
//                         blocked on the limit can re-check the count.
//   compile_count       : number of compiles currently in flight; incremented
//                         before a compile is launched and decremented at the
//                         end of ggml_vk_create_pipeline_func.
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;
static uint32_t compile_count{0};

static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
GGML_ASSERT(parameter_count > 0);
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT

std::lock_guard<std::mutex> guard(device->mutex);

pipeline = std::make_shared<vk_pipeline_struct>();
pipeline->name = name;
pipeline->parameter_count = parameter_count;
Expand Down Expand Up @@ -681,7 +686,17 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co
pipeline->layout);
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;

device->pipelines.insert({ pipeline->name, pipeline });
{
std::lock_guard<std::mutex> guard(device->mutex);
device->pipelines.insert({ pipeline->name, pipeline });
}

{
std::lock_guard<std::mutex> guard(compile_count_mutex);
assert(compile_count > 0);
compile_count--;
}
compile_count_cond.notify_all();
}

static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
Expand Down Expand Up @@ -1190,6 +1205,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();

// One future per pipeline compile launched through the lambda below; the
// enclosing function waits on all of them before it returns.
std::vector<std::future<void>> compiles;
// Local replacement for the old synchronous ggml_vk_create_pipeline(): same
// argument list, but the actual work (ggml_vk_create_pipeline_func) runs on a
// worker thread. At most N compiles are in flight at once, where N is the
// reported hardware concurrency (at least 1).
//
// device and pipeline are forwarded by reference (std::ref), so the objects
// they refer to must stay alive until the matching future has completed.
// name, entrypoint, wg_denoms and specialization_constants are copied into
// the task, so temporaries passed by the caller are safe.
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
    {
        // wait until fewer than N compiles are in progress
        const uint32_t N = std::max(1u, std::thread::hardware_concurrency());
        std::unique_lock<std::mutex> guard(compile_count_mutex);
        compile_count_cond.wait(guard, [&] { return compile_count < N; });
        // Reserve a slot; ggml_vk_create_pipeline_func releases it when done.
        compile_count++;
    }
    try {
        // Request std::launch::async explicitly. With the default policy the
        // implementation may defer the task until the future is waited on; a
        // deferred task would keep its compile_count slot reserved without
        // ever running, and once N slots are stuck the wait above never ends.
        compiles.push_back(std::async(std::launch::async, ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
    } catch (const std::system_error&) {
        // No thread could be started: compile on the calling thread instead.
        // The callee still decrements compile_count, so the slot reserved
        // above is released either way.
        ggml_vk_create_pipeline_func(device, pipeline, name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align);
    }
};

if (device->fp16) {
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
Expand Down Expand Up @@ -1739,6 +1768,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);

for (auto &c : compiles) {
c.wait();
}
}

static vk_device ggml_vk_get_device(size_t idx) {
Expand Down
Loading