Skip to content

Commit

Permalink
sycl: move downsample global_range into common
Browse files Browse the repository at this point in the history
Signed-off-by: zhentaoyu <[email protected]>
  • Loading branch information
zhentaoyu committed Aug 19, 2024
1 parent 3c1b017 commit 94c1ec9
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 14 deletions.
11 changes: 11 additions & 0 deletions ggml/src/ggml-sycl/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try {
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}

int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
const int64_t max_range = std::numeric_limits<int>::max();
int64_t sycl_down_blk_size = block_size;
int64_t global_range = accumulate_block_num * sycl_down_blk_size;
while(global_range > max_range) {
sycl_down_blk_size /= 2;
global_range = accumulate_block_num * sycl_down_blk_size;
}
return sycl_down_blk_size;
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-sycl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,4 +352,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
}

int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);

#endif // GGML_SYCL_COMMON_HPP
8 changes: 1 addition & 7 deletions ggml/src/ggml-sycl/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,7 @@ static void convert_unary_sycl(const void *__restrict__ vx,
const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;

// decrease global range when it exceeds the max int
int local_size = SYCL_DEQUANTIZE_BLOCK_SIZE;
const int64_t max_range = std::numeric_limits<int>::max();
int64_t global_range = num_blocks * local_size;
while(global_range > max_range) {
local_size /= 2;
global_range = num_blocks * local_size;
}
int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
sycl::range<3> block_nums(1, 1, num_blocks);
sycl::range<3> local_range(1, 1, local_size);
{
Expand Down
8 changes: 1 addition & 7 deletions ggml/src/ggml-sycl/im2col.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,7 @@ static void im2col_sycl(
const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;

// decrease global range when it exceeds the max int
int local_size = SYCL_IM2COL_BLOCK_SIZE;
const int64_t max_range = std::numeric_limits<int>::max();
int64_t global_range = batch * IC * OH * num_blocks * local_size;
while(global_range > max_range) {
local_size /= 2;
global_range = batch * IC * OH * num_blocks * local_size;
}
int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
sycl::range<3> block_nums(batch * IC, OH, num_blocks);
sycl::range<3> local_range(1, 1, local_size);

Expand Down

0 comments on commit 94c1ec9

Please sign in to comment.