From 2b4ac7770dd5098cdfb70dcf6faedc9399a5eb79 Mon Sep 17 00:00:00 2001 From: woachk <24752637+woachk@users.noreply.github.com> Date: Sun, 31 Mar 2024 08:15:45 +0200 Subject: [PATCH] kompute: implement op_getrows_f32 op_getrows_f32 is required since https://github.com/ggerganov/llama.cpp/pull/6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again. --- CMakeLists.txt | 2 ++ ggml-kompute.cpp | 17 +++++++++++++--- kompute-shaders/op_getrows_f32.comp | 31 +++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 3 deletions(-) create mode 100644 kompute-shaders/op_getrows_f32.comp diff --git a/CMakeLists.txt b/CMakeLists.txt index 19fdfa46ca4f15..945a51fc432841 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -709,6 +709,7 @@ if (LLAMA_KOMPUTE) kompute-shaders/op_mul_mat_q4_0.comp kompute-shaders/op_mul_mat_q4_1.comp kompute-shaders/op_mul_mat_q6_k.comp + kompute-shaders/op_getrows_f32.comp kompute-shaders/op_getrows_f16.comp kompute-shaders/op_getrows_q4_0.comp kompute-shaders/op_getrows_q4_1.comp @@ -741,6 +742,7 @@ if (LLAMA_KOMPUTE) shaderop_mul_mat_q4_0.h shaderop_mul_mat_q4_1.h shaderop_mul_mat_q6_k.h + shaderop_getrows_f32.h shaderop_getrows_f16.h shaderop_getrows_q4_0.h shaderop_getrows_q4_1.h diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 407062e6fd4762..827ce815c2fe25 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -22,6 +22,7 @@ #include "shaderop_mul_mat_q4_1.h" #include "shaderop_mul_mat_q6_k.h" #include "shaderop_mul_mat_mat_f32.h" +#include "shaderop_getrows_f32.h" #include "shaderop_getrows_f16.h" #include "shaderop_getrows_q4_0.h" #include "shaderop_getrows_q4_1.h" @@ -136,8 +137,7 @@ static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_devi physical_device.getFeatures2(&features2); - if (!availableFeatures11.uniformAndStorageBuffer16BitAccess || - !availableFeatures11.storageBuffer16BitAccess) { + if (!availableFeatures11.storageBuffer16BitAccess) { return false; } @@ -1146,6 +1146,14 @@ static void ggml_vk_get_rows( seq.record(s_algo); } +template +static void ggml_vk_get_rows_f32(Args&&... args) { + const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv, + kp::shader_data::op_getrows_f32_comp_spv_len); + + ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward(args)...); +} + template static void ggml_vk_get_rows_f16(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv, @@ -1371,6 +1379,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) { return op->ne[3] == 1; case GGML_OP_GET_ROWS: switch (op->src[0]->type) { + case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -1649,7 +1658,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml } break; case GGML_OP_GET_ROWS: { - if (src0t == GGML_TYPE_F16) { + if (src0t == GGML_TYPE_F32) { + ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); + } else if (src0t == GGML_TYPE_F16) { ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); } else if (src0t == GGML_TYPE_Q4_0) { ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); diff --git a/kompute-shaders/op_getrows_f32.comp b/kompute-shaders/op_getrows_f32.comp new file mode 100644 index 00000000000000..9d7acdaf8a8e4c --- /dev/null +++ b/kompute-shaders/op_getrows_f32.comp @@ -0,0 +1,31 @@ +#version 450 + +#include "common.comp" + +layout(local_size_x = 1) in; + +layout (binding = 0) readonly buffer tensorInA { float inA[]; }; +layout (binding = 1) readonly buffer tensorInB { int inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int nb01; + int nb1; +} pcs; + +void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) { + for (int j = 0; j < k; j++) { + out_[y + j] = inA[x + j]; + } +} + +void main() { + const uint i = gl_WorkGroupID.x; + const int r = inB[i + pcs.inBOff]; + + dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00); +}