Skip to content

Commit

Permalink
sycl: fix ib in dmmv
Browse files Browse the repository at this point in the history
Signed-off-by: zhentaoyu <[email protected]>
  • Loading branch information
zhentaoyu committed Aug 16, 2024
1 parent 0662d81 commit 3386879
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ggml/src/ggml-sycl/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ template <typename src_t, typename dst_t>
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
const sycl::nd_item<3> &item_ct1) {
const int64_t work_group_size = item_ct1.get_local_range(2);
const int64_t global_id = item_ct1.get_local_id(2) + item_ct1.get_group(2) * work_group_size;
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);

// make each work-item deal with more elements since sycl global range cannot exceed max int
for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-sycl/dmmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
#include "presets.hpp"


static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const sycl::half *x = (const sycl::half *)vx;

// automatic half -> float type cast if dfloat == float
v.x() = x[ib + iqs + 0];
v.y() = x[ib + iqs + 1];
}

static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const float * x = (const float *) vx;

// x is already float here; any float -> dfloat conversion is implicit (a no-op if dfloat == float)
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-sycl/im2col.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ static void im2col_kernel(
int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
const sycl::nd_item<3> &item_ct1) {
const int64_t work_group_size = item_ct1.get_local_range(2);
const int64_t global_id = item_ct1.get_local_id(2) + item_ct1.get_group(2) * work_group_size;
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);

// make each work-item deal with more elements since sycl global range cannot exceed max int
for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
Expand Down Expand Up @@ -95,7 +95,7 @@ void ggml_sycl_op_im2col(

GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);

const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
Expand Down

0 comments on commit 3386879

Please sign in to comment.