Skip to content

Commit

Permalink
kompute: implement op_getrows_f32
Browse files Browse the repository at this point in the history
op_getrows_f32 is required since #6122
for the Vulkan w/ Kompute backend to be functional.

As such, implement this op to make this backend functional again.
  • Loading branch information
woachk committed Mar 31, 2024
1 parent 37e7854 commit 2b4ac77
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ if (LLAMA_KOMPUTE)
kompute-shaders/op_mul_mat_q4_0.comp
kompute-shaders/op_mul_mat_q4_1.comp
kompute-shaders/op_mul_mat_q6_k.comp
kompute-shaders/op_getrows_f32.comp
kompute-shaders/op_getrows_f16.comp
kompute-shaders/op_getrows_q4_0.comp
kompute-shaders/op_getrows_q4_1.comp
Expand Down Expand Up @@ -741,6 +742,7 @@ if (LLAMA_KOMPUTE)
shaderop_mul_mat_q4_0.h
shaderop_mul_mat_q4_1.h
shaderop_mul_mat_q6_k.h
shaderop_getrows_f32.h
shaderop_getrows_f16.h
shaderop_getrows_q4_0.h
shaderop_getrows_q4_1.h
Expand Down
17 changes: 14 additions & 3 deletions ggml-kompute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "shaderop_mul_mat_q4_1.h"
#include "shaderop_mul_mat_q6_k.h"
#include "shaderop_mul_mat_mat_f32.h"
#include "shaderop_getrows_f32.h"
#include "shaderop_getrows_f16.h"
#include "shaderop_getrows_q4_0.h"
#include "shaderop_getrows_q4_1.h"
Expand Down Expand Up @@ -136,8 +137,7 @@ static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_devi

physical_device.getFeatures2(&features2);

if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
!availableFeatures11.storageBuffer16BitAccess) {
if (!availableFeatures11.storageBuffer16BitAccess) {
return false;
}

Expand Down Expand Up @@ -1146,6 +1146,14 @@ static void ggml_vk_get_rows(
seq.record<kp::OpAlgoDispatch>(s_algo);
}

template <typename... Args>
static void ggml_vk_get_rows_f32(Args&&... args) {
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
kp::shader_data::op_getrows_f32_comp_spv_len);

ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_get_rows_f16(Args&&... args) {
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
Expand Down Expand Up @@ -1371,6 +1379,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
return op->ne[3] == 1;
case GGML_OP_GET_ROWS:
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
Expand Down Expand Up @@ -1649,7 +1658,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
} break;
case GGML_OP_GET_ROWS:
{
if (src0t == GGML_TYPE_F16) {
if (src0t == GGML_TYPE_F32) {
ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
} else if (src0t == GGML_TYPE_F16) {
ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
} else if (src0t == GGML_TYPE_Q4_0) {
ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
Expand Down
31 changes: 31 additions & 0 deletions kompute-shaders/op_getrows_f32.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#version 450

#include "common.comp"

layout(local_size_x = 1) in;

layout (binding = 0) readonly buffer tensorInA { float inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };

layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;

void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
for (int j = 0; j < k; j++) {
out_[y + j] = inA[x + j];
}
}

void main() {
const uint i = gl_WorkGroupID.x;
const int r = inB[i + pcs.inBOff];

dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
}

0 comments on commit 2b4ac77

Please sign in to comment.