Skip to content

Commit

Permalink
iq2_s
Browse files Browse the repository at this point in the history
  • Loading branch information
abhilash1910 committed Mar 14, 2024
1 parent 19885d2 commit 08d3b40
Showing 1 changed file with 204 additions and 2 deletions.
206 changes: 204 additions & 2 deletions ggml-sycl.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

//
// MIT license
// Copyright (C) 2024 Intel Corporation
Expand Down Expand Up @@ -4732,6 +4733,36 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr

}

template<typename dst_t>
static void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1,
const uint64_t *iq2s_grid,
const uint8_t *ksigns_iq2xs,
const uint8_t *kmask_iq2xs) {
const int i = item_ct1.get_group(2);
const block_iq2_s * x = (const block_iq1_s *) vx;

const int tid = item_ct1.get_local_id(2);
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib;
const uint8_t * grid1 = (const uint8_t *)(iq1s_grid + qs[2*il+0]);
const uint8_t * grid2 = (const uint8_t *)(iq1s_grid + qs[2*il+1]);
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7];
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
#else
assert(false);
#endif

}


/*
DPCT1110:4: The total declared local variable size in device function
dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
Expand Down Expand Up @@ -7648,6 +7679,64 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
#endif
}

static __dpct_inline__ float
vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs,
const uint64_t *iq2s_grid, const uint64_t *ksigns64) {
#if QK_K == 256
const block_iq2_s * bq2 = (const block_iq2_s *) vbq;

const int ib32 = iqs;
const uint8_t * q8 = bq8_1[ib32].qs;
const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
const uint8_t ls1 = bq2->scales[ib32] & 0xf;
const uint8_t ls2 = bq2->scales[ib32] >> 4;
int sumi1 = 0;
for (int l = 0; l < 2; ++l) {
const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
grid[0] ^ signs0, signs0, std::minus<>());
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
grid[1] ^ signs1, signs1, std::minus<>());
sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
q8 += 8;
}
int sumi2 = 0;
for (int l = 2; l < 4; ++l) {
const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
grid[0] ^ signs0, signs0, std::minus<>());
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
grid[1] ^ signs1, signs1, std::minus<>());
sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
q8 += 8;
}
const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
#else
(void) ksigns64;
assert(false);
return 0.f;
#endif
#else
(void) ksigns64;
assert(false);
return 0.f;
#endif
}



template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
vec_dot_q_mul_mat_sycl_t vec_dot>
Expand Down Expand Up @@ -8504,6 +8593,53 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
}
}


template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq2_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1,
const uint64_t *iq2s_grid_ptr, const uint64_t *ksigns64_ptr ) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);

if (row >= nrows) {
return;
}

const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;

// partial sum for each thread
float tmp = 0.0f;

const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;

for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index

const int iby = i * (qk/QK8_1); // y block index that aligns with ibx

const int iqs =
vdr *
(item_ct1.get_local_id(2) %
(qi / vdr)); // x block quant index when casting the quants to int

tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs, iq2s_grid_ptr, ksigns64_ptr);
}

// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}

if (item_ct1.get_local_id(2) == 0) {
dst[row] = tmp;
}
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1) {
Expand Down Expand Up @@ -10247,6 +10383,36 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
}
}

template <typename dst_t>
static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
{
iq2s_grid.init(*stream);
ksigns_iq2xs.init(*stream);
kmask_iq2xs.init(*stream);

dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});

stream->submit([&](sycl::handler &cgh) {
auto iq2s_grid_ptr_ct1 = iq2s_grid.get_ptr();
auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();

cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_s(
vx, y, item_ct1, iq2s_grid_ptr_ct1,
ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
});
});
}
}


template <typename src_t, typename dst_t>
static void convert_unary_sycl(const void *__restrict__ vx,
dst_t *__restrict__ y, const int k,
Expand Down Expand Up @@ -10301,6 +10467,8 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
return dequantize_row_iq3_s_sycl;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ2_S:
return dequantize_row_iq2_s_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
default:
Expand Down Expand Up @@ -10345,6 +10513,8 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
return dequantize_row_iq3_s_sycl;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ2_S:
return dequantize_row_iq2_s_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;
default:
Expand Down Expand Up @@ -10990,6 +11160,35 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
}
}

static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
iq2s_grid.init(*stream);
ksigns64.init(*stream);

stream->submit([&](sycl::handler &cgh) {
auto iq2s_grid_ptr_ct1 = iq2s_grid.get_ptr();
auto ksigns64_ptr_ct1 = ksigns64.get_ptr();

cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(32)]] {
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S, block_iq2_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1,
iq2s_grid_ptr_ct1, ksigns64_ptr_ct1);
});
});
}
}


static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols_x,
const int nrows_x, const int ncols_y,
Expand Down Expand Up @@ -13738,6 +13937,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
return max_compute_capability >= VER_GEN9 ? 128 : 64;
case GGML_TYPE_IQ3_S:
Expand Down Expand Up @@ -13808,6 +14008,9 @@ inline void ggml_sycl_op_mul_mat_vec_q(
case GGML_TYPE_IQ1_S:
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
break;
default:
GGML_ASSERT(false);
break;
Expand Down Expand Up @@ -17153,8 +17356,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
return false;
}
ggml_type a_type = a->type;
if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ2_S ||
a_type == GGML_TYPE_IQ4_XS) {
if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS) {
return false;
}
return true;
Expand Down

0 comments on commit 08d3b40

Please sign in to comment.