Skip to content

Commit

Permalink
metal : optimize ggml_mul_mat_id (faster Mixtral PP) (ggerganov#4725)
Browse files Browse the repository at this point in the history
* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (ggerganov#4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

* metal : optimize ggml_mul_mat_id (wip)

* metal : minor fix

* metal : opt mul_mm_id
  • Loading branch information
ggerganov authored Jan 2, 2024
1 parent 0ef3ca2 commit f3f62f0
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 46 deletions.
31 changes: 20 additions & 11 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute(
}
};

if (ggml_is_quantized(src0t)) {
GGML_ASSERT(ne00 >= nth0*nth1);
}

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
Expand Down Expand Up @@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute(
// TODO: make this more general
GGML_ASSERT(n_as <= 8);

// max size of the src1ids array in the kernel stack
GGML_ASSERT(ne11 <= 512);

struct ggml_tensor * src2 = gf->nodes[i]->src[2];

const int64_t ne20 = src2 ? src2->ne[0] : 0;
Expand All @@ -1732,32 +1739,29 @@ void ggml_metal_graph_compute(
GGML_ASSERT(!ggml_is_transposed(src2));
GGML_ASSERT(!ggml_is_transposed(src1));

GGML_ASSERT(ne20 % 32 == 0);
// !!!!!!!!! TODO: this assert is probably required but not sure!
//GGML_ASSERT(ne20 >= 64);
GGML_ASSERT(src1t == GGML_TYPE_F32);

const uint r2 = ne12/ne22;
const uint r3 = ne13/ne23;

// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
int ne11_mm_min = 1;
int ne11_mm_min = n_as;

const int idx = ((int32_t *) dst->op_params)[0];

// batch size
GGML_ASSERT(ne01 == ne11);

const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory

// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
// !!!
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
// indirect matrix multiplication
// !!!
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne20 % 32 == 0 && ne20 >= 64 &&
ne11 > ne11_mm_min) {
switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
Expand Down Expand Up @@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
Expand All @@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute(

[encoder setThreadgroupMemoryLength:8192 atIndex:0];

// TODO: processing one row at a time (ne11 -> 1) is not efficient
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
int nth1 = 1;
Expand Down Expand Up @@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute(
} break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
GGML_ASSERT(false && "not implemented");
}
};

if (ggml_is_quantized(src2t)) {
GGML_ASSERT(ne20 >= nth0*nth1);
}

const int64_t _ne1 = 1; // kernels needs a reference in constant memory

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
Expand Down
Loading

0 comments on commit f3f62f0

Please sign in to comment.