From 0662d81ccaa5b22ac45c960cf334cc664e957d1c Mon Sep 17 00:00:00 2001 From: zhentaoyu Date: Mon, 12 Aug 2024 09:19:03 +0000 Subject: [PATCH] sycl: fix convert and dequantize Signed-off-by: zhentaoyu --- ggml/src/ggml-sycl/convert.cpp | 4 ++-- ggml/src/ggml-sycl/dequantize.hpp | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index c6b40c34f2866f..d44629ad59e51a 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -29,7 +29,7 @@ template static void dequantize_block_sycl(const void *__restrict__ vx, dst_t *__restrict__ y, const int64_t k, dpct::queue_ptr stream) { - const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE); + const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE); { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -435,7 +435,7 @@ template static void convert_unary_sycl(const void *__restrict__ vx, dst_t *__restrict__ y, const int64_t k, dpct::queue_ptr stream) { - const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; + const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; // decrease global range when it exceeds the max int int local_size = SYCL_DEQUANTIZE_BLOCK_SIZE; diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 6b25c144889763..8f4041fffce335 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -15,9 +15,9 @@ #include "common.hpp" -typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); -static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -40,7 +40,7 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_1 * x = (const block_q4_1 *) vx; @@ -64,7 +64,7 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q5_0 * x = (const block_q5_0 *) vx; @@ -91,7 +91,7 @@ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q5_1 * x = (const block_q5_1 *) vx; @@ -118,7 +118,7 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q8_0 * x = (const block_q8_0 *) vx; @@ -138,7 +138,7 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib, } template -static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32, +static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { const int64_t i = item_ct1.get_group(2); @@ -168,7 +168,7 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri } template -static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32, +static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { const int64_t i = item_ct1.get_group(2);