From 80321f4b49dbabba01c422d43b700a1b48f65a03 Mon Sep 17 00:00:00 2001
From: Calvin Laurenson <89622328+calvin-laurenson@users.noreply.github.com>
Date: Sat, 15 Jun 2024 17:35:57 -0700
Subject: [PATCH 1/8] cuda sqrt support

---
 ggml-cuda.cu        |  3 +++
 ggml-cuda/unary.cu  | 28 ++++++++++++++++++++++++++++
 ggml-cuda/unary.cuh |  3 +++
 3 files changed, 34 insertions(+)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 593fa4cdaa514..6e7ee676a49a5 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2267,6 +2267,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SQR:
             ggml_cuda_op_sqr(ctx, dst);
             break;
+        case GGML_OP_SQRT:
+            ggml_cuda_op_sqrt(ctx, dst);
+            break;
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;
diff --git a/ggml-cuda/unary.cu b/ggml-cuda/unary.cu
index a5ff96320f23f..5c3f716d7c840 100644
--- a/ggml-cuda/unary.cu
+++ b/ggml-cuda/unary.cu
@@ -92,6 +92,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
+static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sqrt(x[i]);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
@@ -142,6 +151,11 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
+    sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -284,3 +298,17 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
\ No newline at end of file
diff --git a/ggml-cuda/unary.cuh b/ggml-cuda/unary.cuh
index a1d07c04fcd43..4cfb0479e7169 100644
--- a/ggml-cuda/unary.cuh
+++ b/ggml-cuda/unary.cuh
@@ -8,6 +8,7 @@
 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
+#define CUDA_SQRT_BLOCK_SIZE 256
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -28,3 +29,5 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
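
Patch 1 follows the established ggml-cuda pattern for element-wise ops: a bounds-checked kernel, a *_cuda launch helper that sizes the grid by ceil-division ((k + BLOCK_SIZE - 1) / BLOCK_SIZE, so the last partial block is still launched), and a ggml_cuda_op_* wrapper that asserts the tensor is contiguous F32 and dispatches on the backend's stream. The self-contained CUDA program below is an editor's sketch of that launch pattern, not part of the patch; N and BLOCK_SIZE are illustrative, and it uses sqrtf directly (patch 8 below makes the same fix in the real kernel).

// Standalone sketch of the element-wise launch pattern used in patch 1.
// Error handling is omitted for brevity.
#include <cstdio>
#include <cmath>
#include <cuda_runtime.h>

#define BLOCK_SIZE 256 // mirrors CUDA_SQRT_BLOCK_SIZE

static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return; // guard threads in the final partial block
    }
    dst[i] = sqrtf(x[i]);
}

int main() {
    const int N = 1000; // deliberately not a multiple of BLOCK_SIZE
    float hx[N], hy[N];
    for (int i = 0; i < N; i++) hx[i] = (float) i;

    float * dx; float * dy;
    cudaMalloc(&dx, N*sizeof(float));
    cudaMalloc(&dy, N*sizeof(float));
    cudaMemcpy(dx, hx, N*sizeof(float), cudaMemcpyHostToDevice);

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // ceil-division: (1000 + 255) / 256 = 4 blocks, covering 1024 >= 1000 threads
    const int num_blocks = (N + BLOCK_SIZE - 1) / BLOCK_SIZE;
    sqrt_f32<<<num_blocks, BLOCK_SIZE, 0, stream>>>(dx, dy, N);
    cudaStreamSynchronize(stream);

    cudaMemcpy(hy, dy, N*sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < N; i++) {
        if (fabsf(hy[i] - sqrtf(hx[i])) > 1e-5f) {
            printf("mismatch at %d: %f vs %f\n", i, hy[i], sqrtf(hx[i]));
            return 1;
        }
    }
    printf("ok\n");

    cudaStreamDestroy(stream);
    cudaFree(dx);
    cudaFree(dy);
    return 0;
}
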
From 3c4df6ccf3c445f98080eafe882f1a10e20b3a55 Mon Sep 17 00:00:00 2001
From: Calvin Laurenson <89622328+calvin-laurenson@users.noreply.github.com>
Date: Sat, 15 Jun 2024 17:36:09 -0700
Subject: [PATCH 2/8] enable cuda in pca

---
 examples/cvector-generator/pca.hpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp
index 8b95cec374c23..4fb82db1787bc 100644
--- a/examples/cvector-generator/pca.hpp
+++ b/examples/cvector-generator/pca.hpp
@@ -65,13 +65,13 @@ struct pca_model {
 
     pca_model(struct ggml_tensor * t_input) {
 // TODO: enable GPU support when support for GGML_OP_SQRT is added
-// #ifdef GGML_USE_CUDA
-//         fprintf(stderr, "%s: using CUDA backend\n", __func__);
-//         backend = ggml_backend_cuda_init(0); // init device 0
-//         if (!backend) {
-//             fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
-//         }
-// #endif
+#ifdef GGML_USE_CUDA
+        fprintf(stderr, "%s: using CUDA backend\n", __func__);
+        backend = ggml_backend_cuda_init(0); // init device 0
+        if (!backend) {
+            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+#endif
 
 // #ifdef GGML_USE_METAL
 //         fprintf(stderr, "%s: using Metal backend\n", __func__);

From 145f09fc9206d50bc4fb99748cda1b4a32de8fb5 Mon Sep 17 00:00:00 2001
From: Calvin Laurenson <89622328+calvin-laurenson@users.noreply.github.com>
Date: Sat, 15 Jun 2024 18:05:08 -0700
Subject: [PATCH 3/8] fix comments in pca

---
 examples/cvector-generator/pca.hpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp
index 4fb82db1787bc..36eadaac26a12 100644
--- a/examples/cvector-generator/pca.hpp
+++ b/examples/cvector-generator/pca.hpp
@@ -64,7 +64,6 @@ struct pca_model {
     struct ggml_tensor * dev_eigenvector;
 
     pca_model(struct ggml_tensor * t_input) {
-// TODO: enable GPU support when support for GGML_OP_SQRT is added
 #ifdef GGML_USE_CUDA
         fprintf(stderr, "%s: using CUDA backend\n", __func__);
         backend = ggml_backend_cuda_init(0); // init device 0
@@ -73,6 +72,7 @@ struct pca_model {
         }
 #endif
 
+// TODO: enable Metal support when support for GGML_OP_SQRT is added
 // #ifdef GGML_USE_METAL
 //         fprintf(stderr, "%s: using Metal backend\n", __func__);
 //         backend = ggml_backend_metal_init();

From 8ad3edfc406a1f41e84c9d96f295d295f31f3f3c Mon Sep 17 00:00:00 2001
From: Calvin Laurenson <89622328+calvin-laurenson@users.noreply.github.com>
Date: Sat, 15 Jun 2024 18:29:36 -0700
Subject: [PATCH 4/8] add test

---
 tests/test-backend-ops.cpp | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 2b48e623e3476..4cf05b2832ed1 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1063,6 +1063,26 @@ struct test_sqr : public test_case {
     }
 };
 
+// GGML_OP_SQRT
+struct test_sqrt : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sqrt(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_sqrt(ctx, a);
+        return out;
+    }
+};
+
 // GGML_OP_CLAMP
 struct test_clamp : public test_case {
     const ggml_type type;
@@ -2200,6 +2220,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     test_cases.emplace_back(new test_sqr());
+    test_cases.emplace_back(new test_sqrt());
     test_cases.emplace_back(new test_clamp());
 
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
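
Patch 4's test_sqrt mirrors the neighbouring cases: build_graph() emits a single GGML_OP_SQRT node, and the test-backend-ops harness evaluates the graph on both the CPU backend and the backend under test, comparing the outputs within a tolerance. For reference, below is a minimal CPU-only sketch of driving the same op through the public ggml C API; this is an editor's illustration assuming the API as of mid-2024, not code from the patch, so treat exact signatures as unverified against other revisions.

// CPU-only sketch of evaluating GGML_OP_SQRT through the ggml API.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // scratch arena for tensors + graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 4-element F32 input; kept non-negative, since sqrt of a negative is NaN
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    float * data = (float *) a->data;
    data[0] = 0.0f; data[1] = 1.0f; data[2] = 4.0f; data[3] = 9.0f;

    struct ggml_tensor * out = ggml_sqrt(ctx, a);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    const float * res = (const float *) out->data;
    for (int i = 0; i < 4; i++) {
        printf("sqrt(%g) = %g\n", data[i], res[i]); // expect 0 1 2 3
    }

    ggml_free(ctx);
    return 0;
}
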
From f9926746d9cdbd4bbe781850295376c7f42ba1fe Mon Sep 17 00:00:00 2001
From: Calvin Laurenson <89622328+calvin-laurenson@users.noreply.github.com>
Date: Sat, 15 Jun 2024 18:47:20 -0700
Subject: [PATCH 5/8] add sqrt to ggml_backend_cuda_supports_op

---
 ggml-cuda.cu | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 6e7ee676a49a5..b8298ab205e60 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2833,6 +2833,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:

From 97b37313ac69ffd6cb938063b26da24aa84d5a82 Mon Sep 17 00:00:00 2001
From: Calvin Laurenson <89622328+calvin-laurenson@users.noreply.github.com>
Date: Sat, 15 Jun 2024 19:30:43 -0700
Subject: [PATCH 6/8] fix test

---
 tests/test-backend-ops.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 4cf05b2832ed1..7c504e937a851 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1081,6 +1081,13 @@ struct test_sqrt : public test_case {
         ggml_tensor * out = ggml_sqrt(ctx, a);
         return out;
     }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        // fill with positive values
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, 0.0f, 100.0f);
+        }
+    }
 };

From fea1dc98c0fdebd65243500dbbc97070ff8ca265 Mon Sep 17 00:00:00 2001
From: Calvin Laurenson <89622328+calvin-laurenson@users.noreply.github.com>
Date: Sat, 15 Jun 2024 19:38:14 -0700
Subject: [PATCH 7/8] new line

---
 ggml-cuda/unary.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml-cuda/unary.cu b/ggml-cuda/unary.cu
index 5c3f716d7c840..c0ddb8a56562a 100644
--- a/ggml-cuda/unary.cu
+++ b/ggml-cuda/unary.cu
@@ -311,4 +311,4 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
-}
\ No newline at end of file
+}

From 3591f1cc8ca313b91ff823ee65f140e56a472b96 Mon Sep 17 00:00:00 2001
From: Calvin Laurenson
Date: Sun, 16 Jun 2024 08:00:05 -0700
Subject: [PATCH 8/8] Use F32 sqrtf instead of F64 sqrt
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Johannes Gäßler
---
 ggml-cuda/unary.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml-cuda/unary.cu b/ggml-cuda/unary.cu
index c0ddb8a56562a..f9e208011e2a8 100644
--- a/ggml-cuda/unary.cu
+++ b/ggml-cuda/unary.cu
@@ -98,7 +98,7 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
     if (i >= k) {
         return;
     }
-    dst[i] = sqrt(x[i]);
+    dst[i] = sqrtf(x[i]);
 }
 
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
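
Two follow-up patches in the series deserve a brief note. Patch 6 constrains the test inputs to [0, 100]: the default initialization draws values from a range that includes negatives, sqrt of a negative input produces NaN, and a NaN result fails the cross-backend comparison even when both backends agree, so restricting the range keeps the test meaningful. Patch 8 replaces sqrt with sqrtf in the kernel: a bare sqrt call on a float operand can resolve to the double-precision routine, which costs a float-to-double round trip and runs on F64 units that have a small fraction of F32 throughput on most consumer GPUs, while sqrtf is unambiguously single precision. The device-code sketch below is an editor's illustration of the two paths, not code from the series.

// Illustration of the distinction fixed by patch 8.
__device__ float sqrt_promoted(float v) {
    // Written out explicitly: promote to F64, take the slow double-precision
    // square root, then truncate back to F32.
    return (float) sqrt((double) v);
}

__device__ float sqrt_explicit_f32(float v) {
    return sqrtf(v); // always the single-precision square root
}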