diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index e878f4f50f09e2..cf5291b31fe917 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try { << ", line:" << __LINE__ << std::endl; std::exit(1); } + +int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) { + const int64_t max_range = std::numeric_limits::max(); + int64_t sycl_down_blk_size = block_size; + int64_t global_range = accumulate_block_num * sycl_down_blk_size; + while(global_range > max_range) { + sycl_down_blk_size /= 2; + global_range = accumulate_block_num * sycl_down_blk_size; + } + return sycl_down_blk_size; +} diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 86d8b40e8b0133..1dbf4370906065 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -352,4 +352,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor acc) { return acc.template get_multi_ptr().get(); } +int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size); + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 03d6ff9928a708..5fd15e6cdccabb 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -437,13 +437,7 @@ static void convert_unary_sycl(const void *__restrict__ vx, const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; // decrease global range when it exceeds the max int - int local_size = SYCL_DEQUANTIZE_BLOCK_SIZE; - const int64_t max_range = std::numeric_limits::max(); - int64_t global_range = num_blocks * local_size; - while(global_range > max_range) { - local_size /= 2; - global_range = num_blocks * local_size; - } + int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE); sycl::range<3> block_nums(1, 1, num_blocks); sycl::range<3> local_range(1, 1, local_size); { diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 1c304866b55052..6a0a0fcd08c68a 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -64,13 +64,7 @@ static void im2col_sycl( const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; // decrease global range when it exceeds the max int - int local_size = SYCL_IM2COL_BLOCK_SIZE; - const int64_t max_range = std::numeric_limits::max(); - int64_t global_range = batch * IC * OH * num_blocks * local_size; - while(global_range > max_range) { - local_size /= 2; - global_range = batch * IC * OH * num_blocks * local_size; - } + int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE); sycl::range<3> block_nums(batch * IC, OH, num_blocks); sycl::range<3> local_range(1, 1, local_size);