Implemented vector length agnostic SVE using switch case for 512-bit, 256-bit, 128-bit vector lengths
Vithulep committed Sep 3, 2024
1 parent 48baa61 commit 4dbdb6c
Showing 1 changed file with 226 additions and 36 deletions.
262 changes: 226 additions & 36 deletions ggml/src/ggml-quants.c
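Editor's note: before this change, the SVE path only ran when ggml_sve_cnt_b (the SVE register width in bytes) equalled QK8_0, i.e. on 256-bit implementations; the commit replaces that guard with a run-time switch on the vector length in bits. The sketch below shows the dispatch pattern in isolation, using the ACLE svcntb() intrinsic instead of ggml's cached ggml_sve_cnt_b; the kernel_* names are placeholders, not functions from the source.

#include <arm_sve.h>

static void kernel_128(void) { /* 128-bit specialization of the dot product */ }
static void kernel_256(void) { /* 256-bit specialization */ }
static void kernel_512(void) { /* 512-bit specialization */ }

// Minimal sketch of the dispatch used in this commit: pick a kernel based on
// the SVE register width reported by the hardware (svcntb() is bytes per register).
void dispatch_on_vector_length(void) {
    const int vector_length = (int) svcntb() * 8;   // width in bits
    switch (vector_length) {
        case 128: kernel_128(); break;
        case 256: kernel_256(); break;
        case 512: kernel_512(); break;
        default:  break;                            // other widths: handled elsewhere
    }
}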
@@ -3818,14 +3818,20 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
float sumf = 0;

#if defined(__ARM_FEATURE_SVE)
if (ggml_sve_cnt_b == QK8_0) {
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);

svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);

for (; ib + 1 < nb; ib += 2) {
svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);
assert(nb % 2 == 0); // TODO: handle odd nb
const int vector_length = ggml_sve_cnt_b*8;

// VLA Implementation using switch case
switch(vector_length)
{
case 128:
// predicate for activating higher lanes for 4 float32 elements
const svbool_t pg =svptrue_pat_b32(SV_VL4);

for (; ib + 1 < nb; ib += 2) {
const block_q4_0 * restrict x0 = &x[ib + 0];
const block_q4_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict y0 = &y[ib + 0];
@@ -3836,24 +3842,113 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);

// 4-bit -> 8-bit
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(),qx0r, 0x0F));
const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(),qx0r, 0x04));
const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(),qx1r, 0x0F));
const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));

// sub 8
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);

// load y
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);

const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs+16);
const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs+16);
// dot product
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));

sumv0 = svmla_n_f32_x(pg, sumv0, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx0ls, qy0l),svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(pg, sumv1, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx1ls, qy1l),svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));

}

sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));

break;

case 256:
// predicate for activating higher lanes for 16 int8 elements
const svbool_t ptrueh_256 = svptrue_pat_b8(SV_VL16);
// predicate for activating lower lanes for 16 int8 elements
const svbool_t ptruel_256 = svnot_b_z(svptrue_b8(), ptrueh_256);


for (; ib + 1 < nb; ib += 2) {
const block_q4_0 * restrict x0 = &x[ib + 0];
const block_q4_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict y0 = &y[ib + 0];
const block_q8_0 * restrict y1 = &y[ib + 1];

// load x
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);

// 4-bit -> 8-bit
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel_256, svand_n_u8_m(ptrueh_256, qx0r, 0x0F), 0x04));
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel_256, svand_n_u8_m(ptrueh_256, qx1r, 0x0F), 0x04));

// sub 8
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);

// load y
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);

// dot product
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
}

sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));

break;

case 512:
// predicate for activating higher lanes for 32 int8 elements
const svbool_t ptrue = svptrue_pat_b8(SV_VL32);
// predicate for activating higher lanes for 16 int8 elements
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
// predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
const svbool_t ptruel = svnot_b_z(ptrue, ptrueh);

for (; ib < nb; ib += 2) {
const block_q4_0 * restrict x0 = &x[ib + 0];
const block_q4_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict y0 = &y[ib + 0];
const block_q8_0 * restrict y1 = &y[ib + 1];

// load x
const svuint8_t qx0r = svld1rq_u8(ptrue, x0->qs);
const svuint8_t qx1r = svld1rq_u8(ptrue, x1->qs);

// 4-bit -> 8-bit
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));

// sub 8
const svint8_t qx0s = svsub_n_s8_x(ptrue, qx0, 8);
const svint8_t qx1s = svsub_n_s8_x(ptrue, qx1, 8);

// load y
const svint8_t qy0 = svld1_s8(ptrue, y0->qs);
const svint8_t qy1 = svld1_s8(ptrue, y1->qs);

// dot product
sumv0 = svmla_n_f32_x(ptrue, sumv0, svcvt_f32_s32_x(ptrue, svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(ptrue, sumv1, svcvt_f32_s32_x(ptrue, svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
}
sumf = svaddv_f32(ptrue, svadd_f32_x(ptrue, sumv0, sumv1));
break;

default:
break;

}

#elif defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
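Editor's note on the Q4_0 handling above: each block_q4_0 stores an fp16 scale d plus 16 bytes of packed nibbles, where the low nibble of byte i is element i and the high nibble is element i+16, both stored with a +8 offset. The 256- and 512-bit cases undo this with a predicated AND on the first 16 byte lanes and a logical shift right on the next 16, while the 128-bit case keeps the two halves in separate registers; all of them then subtract 8 and feed sdot. A scalar reference of the per-block arithmetic, as a hedged sketch (the function name is illustrative, not from the source):

#include <stdint.h>

// Scalar equivalent of one Q4_0 x Q8_0 block (32 elements) as computed by the
// SVE paths above: unpack nibbles, subtract the +8 bias, integer dot product,
// then scale by the two fp16 block scales (passed here already converted to float).
static float dot_q4_0_q8_0_block(const uint8_t qx[16], float d_x,
                                 const int8_t  qy[32], float d_y) {
    int32_t isum = 0;
    for (int i = 0; i < 16; ++i) {
        const int x_lo = (qx[i] & 0x0F) - 8;  // low nibble  -> element i
        const int x_hi = (qx[i] >> 4)   - 8;  // high nibble -> element i + 16
        isum += x_lo * qy[i];
        isum += x_hi * qy[i + 16];
    }
    return d_x * d_y * (float) isum;
}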
@@ -5303,29 +5398,124 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
float sumf = 0;

#if defined(__ARM_FEATURE_SVE)
if (ggml_sve_cnt_b == QK8_0) {
svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);

for (; ib + 1 < nb; ib += 2) {
const block_q8_0 * restrict x0 = &x[ib + 0];
const block_q8_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict y0 = &y[ib + 0];
const block_q8_0 * restrict y1 = &y[ib + 1];

// load x
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);

// load y
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);

assert(nb % 2 == 0); // TODO: handle odd nb
const int vector_length = ggml_sve_cnt_b*8;

// VLA Implementation for SVE
switch(vector_length)
{
case 128:
// predicate for activating lanes for 16 Int8 elements
svbool_t pg1 =svptrue_pat_b8(SV_VL16);
svbool_t pg =svptrue_pat_b32(SV_VL4);
for (; ib + 1 < nb; ib += 2) {

const block_q8_0 * restrict x0 = &x[ib + 0];
const block_q8_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict y0 = &y[ib + 0];
const block_q8_0 * restrict y1 = &y[ib + 1];

// load x
const svint8_t qx0_0 = svld1_s8(pg1, x0->qs);
const svint8_t qx0_1 = svld1_s8(pg1, x0->qs+16);
const svint8_t qx1_0 = svld1_s8(pg1, x1->qs);
const svint8_t qx1_1 = svld1_s8(pg1, x1->qs+16);

// load y
const svint8_t qy0_0 = svld1_s8(pg1, y0->qs);
const svint8_t qy0_1 = svld1_s8(pg1, y0->qs+16);
const svint8_t qy1_0 = svld1_s8(pg1, y1->qs);
const svint8_t qy1_1 = svld1_s8(pg1, y1->qs+16);

sumv0 = svmla_n_f32_x(pg, sumv0, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(pg, sumv1, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));

}

sumf = svaddv_f32(pg, svadd_f32_x(pg, sumv0, sumv1));
break;

case 256:
//printf("sve256");
for (; ib + 1 < nb; ib += 2) {
const block_q8_0 * restrict x0 = &x[ib + 0];
const block_q8_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict y0 = &y[ib + 0];
const block_q8_0 * restrict y1 = &y[ib + 1];

// load x
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);

// load y
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);

sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));

}
sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
break;

case 512:
// predicate for activating high 256 bit
const svbool_t ptrueh = svptrue_pat_b8(SV_VL32);
// predicate for activating low 256 bit
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);

// predicate for activating high lanes for 8 float32 elements
svbool_t asd = svptrue_pat_b32(SV_VL8);
// predicate for activating low lanes for 8 float32 elements
svbool_t dsa = svnot_b_z(svptrue_b32(), asd);

svfloat32_t sumv00 = svdup_n_f32(0.0f);

for (; ib+1 < nb; ib += 2) {

const block_q8_0 * restrict x0 = &x[ib + 0];
const block_q8_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict y0 = &y[ib + 0];
const block_q8_0 * restrict y1 = &y[ib + 1];

//load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits
// and add them to make one 64 element vector
// load x
const svint8_t qx_32 = svld1_s8(ptrueh,x0->qs);
svint8_t qx_64 = svld1_s8(ptruel,x0->qs+2);
qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);

// load y
const svint8_t qy_32 = svld1_s8(ptrueh,y0->qs);
svint8_t qy_64 = svld1_s8(ptruel,y0->qs+2);
qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);

// scale creation
float32_t deq1= GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d);
float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d);

// duplicate deq1 in first half of vector and deq2 in second half of vector
svfloat32_t temp = svdup_f32_m(svdup_f32_z(asd, deq1), dsa,deq2);


svfloat32_t sumvt = svdup_n_f32(0.0f);

sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));

sumv00 = svmla_f32_m(svptrue_b32(),sumv00,sumvt,temp);

}

sumf = svaddv_f32(svptrue_b32(), sumv00);
break;

sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
}
default:
break;

sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
}
#elif defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
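Editor's note on the q8_0·q8_0 cases above: the 128-bit path splits each 32-byte block into two 16-byte loads and sums two sdot results per block, while the 512-bit path packs two consecutive blocks into a single 64-lane register (the svld1_s8(ptruel, x0->qs + 2) load lands on x1->qs because the active lanes start at byte offset 32 and sizeof(block_q8_0) is 34: a 2-byte fp16 scale plus 32 int8 values) and applies the two scales through a float vector whose lower half holds deq1 and upper half deq2. A hedged scalar sketch of what one 512-bit iteration accumulates (the function name is illustrative, not from the source):

#include <stdint.h>

// Scalar equivalent of one iteration of the 512-bit q8_0 x q8_0 loop above:
// two 32-element blocks are dotted against y, and each partial sum is weighted
// by the product of its own fp16 scales (passed here already converted to float).
static float dot_q8_0_q8_0_pair(const int8_t x0[32], float d_x0, const int8_t y0[32], float d_y0,
                                const int8_t x1[32], float d_x1, const int8_t y1[32], float d_y1) {
    int32_t isum0 = 0, isum1 = 0;
    for (int i = 0; i < 32; ++i) {
        isum0 += (int32_t) x0[i] * y0[i];  // block ib + 0
        isum1 += (int32_t) x1[i] * y1[i];  // block ib + 1
    }
    return (float) isum0 * d_x0 * d_y0 + (float) isum1 * d_x1 * d_y1;
}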
