ggml : add initial BF16 support
ggml-ci
ggerganov committed Nov 6, 2024
1 parent 1dc04b2 commit 6109cf1
Showing 3 changed files with 80 additions and 8 deletions.
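Background for the diff below: bfloat16 keeps float32's 8-bit exponent and truncates the mantissa to 7 bits, so a BF16 value is simply the upper 16 bits of the corresponding FP32 bit pattern. That is why the kernels in this commit can treat BF16 as a storage format and convert by casting. A minimal C sketch of the conversion (illustration only, not part of this commit; ggml's own helpers additionally special-case NaN):

    #include <stdint.h>
    #include <string.h>

    // FP32 -> BF16 with round-to-nearest-even (NaN special-casing omitted).
    static uint16_t fp32_to_bf16(float f) {
        uint32_t bits;
        memcpy(&bits, &f, sizeof(bits));
        bits += 0x7FFF + ((bits >> 16) & 1); // round to nearest, ties to even
        return (uint16_t)(bits >> 16);
    }

    // BF16 -> FP32 is exact: place the 16 bits in the high half of a float.
    static float bf16_to_fp32(uint16_t h) {
        uint32_t bits = (uint32_t)h << 16;
        float f;
        memcpy(&f, &bits, sizeof(f));
        return f;
    }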
3 changes: 3 additions & 0 deletions common/common.cpp
@@ -1003,6 +1003,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
     if (s == "f16") {
         return GGML_TYPE_F16;
     }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
     if (s == "q8_0") {
         return GGML_TYPE_Q8_0;
     }
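The common.cpp change above makes "bf16" a valid KV-cache type string (in llama.cpp this parser is fed by the --cache-type-k/--cache-type-v flags), giving F16-sized cache entries with FP32 dynamic range. A back-of-the-envelope sketch of the memory impact, using hypothetical Llama-style dimensions (not from this commit):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // Hypothetical model/runtime dimensions, for illustration only.
        const uint64_t n_layer   = 32;    // transformer layers
        const uint64_t n_ctx     = 8192;  // context length
        const uint64_t n_embd_kv = 1024;  // per-layer K (and V) row width

        const uint64_t elems = 2 * n_layer * n_ctx * n_embd_kv; // K and V

        printf("f32  KV cache: %llu MiB\n", (unsigned long long)(elems * 4 >> 20)); // 2048 MiB
        printf("bf16 KV cache: %llu MiB\n", (unsigned long long)(elems * 2 >> 20)); // 1024 MiB
        return 0;
    }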
68 changes: 61 additions & 7 deletions ggml/src/ggml-metal.m
@@ -120,6 +120,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
@@ -150,6 +151,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
@@ -195,6 +200,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
@@ -300,8 +306,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -615,6 +623,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -641,6 +650,10 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
@@ -690,6 +703,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
@@ -793,10 +807,12 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -887,8 +903,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
 
 static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
     for (size_t i = 0, n = 3; i < n; ++i) {
-        if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
-            return false;
+        if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16 &&
+            op->op != GGML_OP_GET_ROWS &&
+            op->op != GGML_OP_MUL_MAT &&
+            op->op != GGML_OP_VIEW &&
+            op->op != GGML_OP_CPY) {
+            GGML_LOG_ERROR("unsupported BF16 op = %s, src[%zu] = %s\n", ggml_op_name(op->op), i, ggml_type_name(op->src[i]->type));
+            GGML_ASSERT(false);
         }
     }
 
@@ -969,6 +990,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 switch (op->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
                     case GGML_TYPE_Q8_0:
                     case GGML_TYPE_Q4_0:
                     case GGML_TYPE_Q4_1:
@@ -980,11 +1002,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                         return false;
                 }
             case GGML_TYPE_F16:
+            case GGML_TYPE_BF16:
                 switch (op->type) {
-                   case GGML_TYPE_F32:
-                   case GGML_TYPE_F16:
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
                         return true;
-                   default:
+                    default:
                         return false;
                 }
             default:
@@ -1855,6 +1879,7 @@ static void ggml_metal_encode_node(
                 switch (src0->type) {
                     case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;
                     case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8 == 0); break;
+                    case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
                     default: break;
}

@@ -1863,6 +1888,7 @@ static void ggml_metal_encode_node(
                 switch (src0->type) {
                     case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
                     case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
+                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
                     case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
                     case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
                     case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
@@ -1940,6 +1966,25 @@ static void ggml_metal_encode_node(
                                 nrows = 4;
                             }
                         } break;
+                    case GGML_TYPE_BF16:
+                        {
+                            nth0 = 32;
+                            nth1 = 1;
+                            if (src1t == GGML_TYPE_F32) {
+                                if (ne11 * ne12 < 4) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
+                                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
+                                    nrows = ne11;
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
+                                    nrows = 4;
+                                }
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
+                                nrows = 4;
+                            }
+                        } break;
                     case GGML_TYPE_Q4_0:
                         {
                             nth0 = 8;
@@ -2438,6 +2483,7 @@ static void ggml_metal_encode_node(
             switch (src0->type) {
                 case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
                 case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
+                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
                 case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
                 case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
                 case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
@@ -3237,6 +3283,7 @@ static void ggml_metal_encode_node(
                         switch (dstt) {
                             case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
                             case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+                            case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
                             case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
                             case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
                             case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -3254,6 +3301,13 @@ static void ggml_metal_encode_node(
                             default: GGML_ABORT("not implemented");
                         };
                     } break;
+                case GGML_TYPE_BF16:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
+                            default: GGML_ASSERT(false && "not implemented");
+                        };
+                    } break;
                 default: GGML_ABORT("not implemented");
             }

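A note on the matrix-vector dispatch added above: the encoder picks among four new BF16 kernels based on operand shapes, mirroring the existing F16 paths. A simplified C restatement of that selection (names illustrative, logic taken from the hunk at @@ -1940):

    #include <stdint.h>
    #include <stdbool.h>

    enum bf16_mv_kernel { BF16_MV_F32_1ROW, BF16_MV_F32_L4, BF16_MV_F32, BF16_MV_BF16_BF16 };

    // Mirrors the BF16 case added to ggml_metal_encode_node (simplified).
    static enum bf16_mv_kernel pick_bf16_mv_kernel(bool src1_is_f32,
            int64_t ne00, int64_t ne01, int64_t ne11, int64_t ne12) {
        if (!src1_is_f32) {
            return BF16_MV_BF16_BF16;   // both operands BF16
        }
        if (ne11 * ne12 < 4) {
            return BF16_MV_F32_1ROW;    // very small src1 batch
        }
        if (ne00 >= 128 && ne01 >= 8 && ne00 % 4 == 0) {
            return BF16_MV_F32_L4;      // long rows with width a multiple of 4
        }
        return BF16_MV_F32;             // general case; the encoder sets nrows = 4
    }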
17 changes: 16 additions & 1 deletion ggml/src/ggml-metal.metal
@@ -16,6 +16,8 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 };
 
+typedef matrix<bfloat, 4, 4> bfloat4x4;
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -27,6 +29,11 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
     reg = (type4x4)(*src);
 }
 
+template <typename type4x4>
+void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
+    reg = (type4x4)(*src);
+}
+
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
@@ -2041,6 +2048,8 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
 template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
 template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
 template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
+template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
+template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
 
 template<typename T, typename T4>
 kernel void kernel_mul_mv_1row(
@@ -2110,6 +2119,7 @@ kernel void kernel_mul_mv_1row(
 typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
 
 template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
+template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
 
 // Assumes row size (ne00) is a multiple of 4
 template<typename T, typename T4>
@@ -2169,6 +2179,7 @@ kernel void kernel_mul_mv_l4(
 typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
 
 template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
+template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
@@ -3567,8 +3578,10 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
 
 template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
 template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
-template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
+template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
 template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
+template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
+template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
 
 kernel void kernel_cpy_f32_q8_0(
     device const float * src0,
@@ -6473,6 +6486,7 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
 
 template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
 template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
+template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
 
 typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;

@@ -6504,6 +6518,7 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de
 
 template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
 template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
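End to end, the new type can be exercised through ggml's public C API. A minimal CPU-side sketch (not from this commit; backend selection, data initialization, and error handling omitted, and the exact graph helpers assumed to match the ggml API of this era):

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // BF16 weights multiplied by F32 activations -> F32 result,
        // the same operand combination the new Metal kernels cover.
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_BF16, 64, 8);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  64, 1);
        struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, c);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        ggml_free(ctx);
        return 0;
    }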
