From 583fd6b000ec9ad1b465b5c98524f4a0ae388077 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 15 May 2024 08:44:16 +0200 Subject: [PATCH 1/8] server bench: fix bench not waiting for model load (#7284) --- examples/server/bench/bench.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/server/bench/bench.py b/examples/server/bench/bench.py index 86c5de101445c..25ac29c4c7145 100644 --- a/examples/server/bench/bench.py +++ b/examples/server/bench/bench.py @@ -293,13 +293,14 @@ def server_log(in_stream, out_stream): def is_server_listening(server_fqdn, server_port): - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: - result = sock.connect_ex((server_fqdn, server_port)) - _is_server_listening = result == 0 - if _is_server_listening: - print(f"server is listening on {server_fqdn}:{server_port}...") - return _is_server_listening - + try: + url = f"{server_fqdn}:{server_port}/health" + if not url.startswith("http://"): + url = f"http://{url}" + result = requests.get(url) + return result.status_code == 200 + except Exception: + return False def escape_metric_name(metric_name): return re.sub('[^A-Z0-9]', '_', metric_name.upper()) From 48aa8fd1f213a69b41569f809cc954f24dbc4366 Mon Sep 17 00:00:00 2001 From: John Balis Date: Wed, 15 May 2024 03:52:33 -0500 Subject: [PATCH 2/8] ggml : add `ggml_upscale_ext` (ggml/814) * initial commit with CPU implementation of upscale to shape and test, cuda implementation next * experimental commit to see if dst shape is correct * test version * test * removed unnecessary params * refactor * fixed tests * ggml : metal impl + cleanup + sycl dev warnings * patched ggml_upscale cuda op to handle non-contiguous tensors, added test for non-contiguous behavior * metal : fix upsacle op to support nb00 + style --------- Co-authored-by: Georgi Gerganov --- ggml-cuda/upscale.cu | 63 +++++++++++++++++++------------------ ggml-metal.m | 10 ++++-- ggml-metal.metal | 21 ++++++++----- ggml-sycl.cpp | 4 +++ ggml.c | 64 ++++++++++++++++++++++++++++---------- ggml.h | 12 +++++++ tests/test-backend-ops.cpp | 32 +++++++++++++++++-- 7 files changed, 146 insertions(+), 60 deletions(-) diff --git a/ggml-cuda/upscale.cu b/ggml-cuda/upscale.cu index 2f62fed488835..cf513c3ade7c4 100644 --- a/ggml-cuda/upscale.cu +++ b/ggml-cuda/upscale.cu @@ -1,35 +1,36 @@ #include "upscale.cuh" -static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) { - // blockIdx.z: idx of ne02*ne03 - // blockIdx.y: idx of ne01*scale_factor, aka ne1 - // blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE - // ne00xne01: ne00 * ne01 - int ne0 = ne00 * scale_factor; - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { +static __global__ void upscale_f32(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int ne13, + const float sf0, const float sf1, const float sf2, const float sf3) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index >= ne10 * ne11 * ne12 * ne13) { return; } - // operation - int i00 = nidx / scale_factor; - int i01 = blockIdx.y / scale_factor; - int offset_src = - i00 + - i01 * ne00 + - blockIdx.z * ne00xne01; - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - dst[offset_dst] = x[offset_src]; + + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = (index / (ne10 * ne11 * ne12)) % ne13; + + int i00 = i10 / sf0; + int i01 = i11 / sf1; + int i02 = i12 / sf2; + int i03 = i13 / sf3; + + dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); } -static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03, - const int scale_factor, cudaStream_t stream) { - int ne0 = (ne00 * scale_factor); - int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; - dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03); - upscale_f32<<>>(x, dst, ne00, ne00 * ne01, scale_factor); +static void upscale_f32_cuda(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int ne13, + const float sf0, const float sf1, const float sf2, const float sf3, + cudaStream_t stream) { + int dst_size = ne10 * ne11 * ne12 * ne13; + int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; + + upscale_f32<<>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); } void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -39,10 +40,12 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + GGML_ASSERT( dst->type == GGML_TYPE_F32); - const int scale_factor = dst->op_params[0]; + const float sf0 = (float)dst->ne[0]/src0->ne[0]; + const float sf1 = (float)dst->ne[1]/src0->ne[1]; + const float sf2 = (float)dst->ne[2]/src0->ne[2]; + const float sf3 = (float)dst->ne[3]/src0->ne[3]; - upscale_f32_cuda(src0_d, dst_d, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, stream); + upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); } diff --git a/ggml-metal.m b/ggml-metal.m index 390a1cd789081..b0b16dbf77160 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2353,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute( { GGML_ASSERT(src0->type == GGML_TYPE_F32); - const int sf = dst->op_params[0]; + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; @@ -2376,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf length:sizeof(sf) atIndex:18]; + [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; + [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; + [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; + [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); diff --git a/ggml-metal.metal b/ggml-metal.metal index 57fdf564e17f6..386e9195fcffa 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1852,7 +1852,10 @@ kernel void kernel_upscale_f32( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant int32_t & sf, + constant float & sf0, + constant float & sf1, + constant float & sf2, + constant float & sf3, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -1861,15 +1864,17 @@ kernel void kernel_upscale_f32( const int64_t i2 = tgpig.y; const int64_t i1 = tgpig.x; - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1/sf; - - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + const int64_t i03 = i3/sf3; + const int64_t i02 = i2/sf2; + const int64_t i01 = i1/sf1; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = src0_ptr[i0/sf]; + const int64_t i00 = i0/sf0; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_ptr[0] = src0_ptr[0]; } } diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 724070eb91080..b15efb70421da 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -13987,6 +13987,10 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0, GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors +#pragma message("TODO: generalize upscale operator") +#pragma message(" https://github.com/ggerganov/ggml/pull/814") + GGML_ASSERT(false && "TODO: generalize upscale operator); + const int scale_factor = dst->op_params[0]; upscale_f32_sycl(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream); diff --git a/ggml.c b/ggml.c index 03b609dddce10..f09cc30609bfd 100644 --- a/ggml.c +++ b/ggml.c @@ -6293,7 +6293,10 @@ struct ggml_tensor * ggml_pool_2d( static struct ggml_tensor * ggml_upscale_impl( struct ggml_context * ctx, struct ggml_tensor * a, - int scale_factor) { + int ne0, + int ne1, + int ne2, + int ne3) { bool is_node = false; if (a->grad) { @@ -6301,19 +6304,45 @@ static struct ggml_tensor * ggml_upscale_impl( is_node = true; } + GGML_ASSERT(a->ne[0] <= ne0); + GGML_ASSERT(a->ne[1] <= ne1); + GGML_ASSERT(a->ne[2] <= ne2); + GGML_ASSERT(a->ne[3] <= ne3); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, - a->ne[0] * scale_factor, - a->ne[1] * scale_factor, - a->ne[2], a->ne[3]); + ne0, + ne1, + ne2, + ne3 + ); result->op = GGML_OP_UPSCALE; - result->op_params[0] = scale_factor; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; return result; } +struct ggml_tensor * ggml_upscale( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor) { + return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]); +} + +struct ggml_tensor * ggml_upscale_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3) { + return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3); +} + +// ggml_pad + struct ggml_tensor * ggml_pad( struct ggml_context * ctx, struct ggml_tensor * a, @@ -6338,12 +6367,7 @@ struct ggml_tensor * ggml_pad( return result; } -struct ggml_tensor * ggml_upscale( - struct ggml_context * ctx, - struct ggml_tensor * a, - int scale_factor) { - return ggml_upscale_impl(ctx, a, scale_factor); -} +// ggml_arange struct ggml_tensor * ggml_arange( struct ggml_context * ctx, @@ -6365,6 +6389,8 @@ struct ggml_tensor * ggml_arange( return result; } +// ggml_timestep_embedding + struct ggml_tensor * ggml_timestep_embedding( struct ggml_context * ctx, struct ggml_tensor * timesteps, @@ -14820,25 +14846,28 @@ static void ggml_compute_forward_upscale_f32( return; } - GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); const int ith = params->ith; const int nth = params->nth; GGML_TENSOR_UNARY_OP_LOCALS - const int scale_factor = dst->op_params[0]; + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; // TODO: optimize for (int64_t i3 = 0; i3 < ne3; i3++) { - const int64_t i03 = i3; + const int64_t i03 = i3 / sf3; for (int64_t i2 = ith; i2 < ne2; i2 += nth) { - const int64_t i02 = i2; + const int64_t i02 = i2 / sf2; for (int64_t i1 = 0; i1 < ne1; i1++) { - const int64_t i01 = i1 / scale_factor; + const int64_t i01 = i1 / sf1; for (int64_t i0 = 0; i0 < ne0; i0++) { - const int64_t i00 = i0 / scale_factor; + const int64_t i00 = i0 / sf0; const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); @@ -14868,6 +14897,7 @@ static void ggml_compute_forward_upscale( } } + // ggml_compute_forward_pad static void ggml_compute_forward_pad_f32( diff --git a/ggml.h b/ggml.h index 25f4f73a8d96b..5e121604a80e6 100644 --- a/ggml.h +++ b/ggml.h @@ -1674,12 +1674,24 @@ extern "C" { float p1); // nearest interpolate + // multiplies ne0 and ne1 by scale factor // used in stable-diffusion GGML_API struct ggml_tensor * ggml_upscale( struct ggml_context * ctx, struct ggml_tensor * a, int scale_factor); + // nearest interpolate + // nearest interpolate to specified dimensions + // used in tortoise.cpp + GGML_API struct ggml_tensor * ggml_upscale_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3); + // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0] GGML_API struct ggml_tensor * ggml_pad( struct ggml_context * ctx, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f080f7e22cd35..85ef21c2a5c19 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1329,23 +1329,47 @@ struct test_upscale : public test_case { const ggml_type type; const std::array ne; const int32_t scale_factor; + const bool transpose; std::string vars() override { - return VARS_TO_STR3(type, ne, scale_factor); + return VARS_TO_STR4(type, ne, scale_factor, transpose); } test_upscale(ggml_type type = GGML_TYPE_F32, std::array ne = {512, 512, 3, 1}, - int32_t scale_factor = 2) - : type(type), ne(ne), scale_factor(scale_factor) {} + int32_t scale_factor = 2, bool transpose = false) + : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + if (transpose) a = ggml_transpose(ctx, a); ggml_tensor * out = ggml_upscale(ctx, a, scale_factor); return out; } }; +// GGML_OP_UPSCALE (ext) +struct test_upscale_ext : public test_case { + const ggml_type type; + const std::array ne; + const std::array ne_tgt; + + std::string vars() override { + return VARS_TO_STR3(type, ne, ne_tgt); + } + + test_upscale_ext(ggml_type type = GGML_TYPE_F32, + std::array ne = {2, 5, 7, 11}, + std::array ne_tgt = {5, 7, 11, 13}) + : type(type), ne(ne), ne_tgt(ne_tgt) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]); + return out; + } +}; + // GGML_OP_GROUP_NORM struct test_group_norm : public test_case { const ggml_type type; @@ -2169,6 +2193,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_upscale()); + test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true)); + test_cases.emplace_back(new test_upscale_ext()); test_cases.emplace_back(new test_group_norm()); test_cases.emplace_back(new test_acc()); test_cases.emplace_back(new test_pad()); From 29499bb59383c2a8c5d557a90abb08b696cef7f6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 15 May 2024 13:23:41 +0300 Subject: [PATCH 3/8] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index b2b5f8fc14fa4..57bede67b4f19 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -fafd5e7f89382b8cfb51e3dac8d4f1500ca44918 +126d34985705a5a2222723c145cb4e125ac689f3 From ea3b0590ee33d3573eb8ef76f88cc60f36d2a38d Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 15 May 2024 20:01:12 +0800 Subject: [PATCH 4/8] embedding : free the batch after execution (#7297) --- examples/embedding/embedding.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index c85a2da53d129..0c921ed69badb 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -211,6 +211,7 @@ int main(int argc, char ** argv) { // clean up llama_print_timings(ctx); + llama_batch_free(batch); llama_free(ctx); llama_free_model(model); llama_backend_free(); From 9a17ab914b0aa7353389c656a3f2a0f086726868 Mon Sep 17 00:00:00 2001 From: AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> Date: Wed, 15 May 2024 13:26:30 +0100 Subject: [PATCH 5/8] Add missing " (#7303) --- ggml-sycl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index b15efb70421da..19d22d63753ab 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -13989,7 +13989,7 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0, #pragma message("TODO: generalize upscale operator") #pragma message(" https://github.com/ggerganov/ggml/pull/814") - GGML_ASSERT(false && "TODO: generalize upscale operator); + GGML_ASSERT(false && "TODO: generalize upscale operator"); const int scale_factor = dst->op_params[0]; From 344f9126cc0d15891fde9472fe40b8572628ad7d Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 15 May 2024 15:08:48 +0200 Subject: [PATCH 6/8] ggml : tag ggml_tensor::backend as deprecated (#7290) --- examples/llava/llava.cpp | 15 --------------- ggml-backend.c | 1 - ggml.c | 10 ++++++++++ ggml.h | 3 ++- 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 9a990bb182f35..63878d176b0bb 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -88,7 +88,6 @@ static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair< // Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out) static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) { struct { - struct ggml_tensor * newline; struct ggml_context * ctx; } model; @@ -150,20 +149,6 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector model.ctx = ggml_init(params); - ggml_tensor * newline_tmp = clip_get_newline_tensor(ctx_clip); - model.newline = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, newline_tmp->ne[0]); - if (newline_tmp->backend != GGML_BACKEND_TYPE_CPU) { - if (newline_tmp->buffer == NULL) { - LOG_TEE("newline_tmp tensor buffer is NULL\n"); - } - ggml_backend_tensor_get(newline_tmp, model.newline->data, 0, ggml_nbytes(newline_tmp)); - } else { - model.newline->data = newline_tmp->data; - if (model.newline->data == NULL) { - LOG_TEE("newline_tmp tensor data is NULL\n"); - } - } - struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4 // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false); // fill it with the image embeddings, ignoring the base diff --git a/ggml-backend.c b/ggml-backend.c index dd090a583f685..9e35ce98d7ace 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -1895,7 +1895,6 @@ void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * t tensor->buffer = buffer; tensor->data = (char *)tensor->view_src->data + tensor->view_offs; - tensor->backend = tensor->view_src->backend; ggml_backend_buffer_init_tensor(buffer, tensor); } diff --git a/ggml.c b/ggml.c index f09cc30609bfd..67e17a2102b16 100644 --- a/ggml.c +++ b/ggml.c @@ -3178,6 +3178,12 @@ static struct ggml_tensor * ggml_new_tensor_impl( struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); +#ifdef __clang__ + // temporary until ggml_tensor::backend is removed + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif + *result = (struct ggml_tensor) { /*.type =*/ type, /*.backend =*/ GGML_BACKEND_TYPE_CPU, @@ -3200,6 +3206,10 @@ static struct ggml_tensor * ggml_new_tensor_impl( /*.padding =*/ { 0 }, }; +#ifdef __clang__ + #pragma clang diagnostic pop +#endif + // TODO: this should not be needed as long as we don't rely on aligned SIMD loads //ggml_assert_aligned(result->data); diff --git a/ggml.h b/ggml.h index 5e121604a80e6..8c13f4ba89c6e 100644 --- a/ggml.h +++ b/ggml.h @@ -565,7 +565,8 @@ extern "C" { // n-dimensional tensor struct ggml_tensor { enum ggml_type type; - enum ggml_backend_type backend; + + GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor"); struct ggml_backend_buffer * buffer; From dc020985b8755dd6aa93a2f002f43c3ede808cce Mon Sep 17 00:00:00 2001 From: agray3 Date: Wed, 15 May 2024 14:44:49 +0100 Subject: [PATCH 7/8] Avoid unnecessarily disabling CUDA graphs (#7302) As discussed in PR #6766, CUDA graphs were being disabled in the presence of long prompts. This fixes the issue by avoiding the consective update counter from incrementing unnecessarily for tokens in which cuda graphs are disabled due to batch size > 1. --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 75a2ad480877d..04b6e52859ed5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2558,7 +2558,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. - if (cuda_graph_update_required) { + if (use_cuda_graph && cuda_graph_update_required) { cuda_ctx->cuda_graph->number_consecutive_updates++; } else { cuda_ctx->cuda_graph->number_consecutive_updates = 0; From e1b40ac3b94824d761b5e26ea1bc5692706029d9 Mon Sep 17 00:00:00 2001 From: kunnis Date: Wed, 15 May 2024 12:59:12 -0500 Subject: [PATCH 8/8] ggml : use dynamic thread scheduling for matrix multiplication (#6915) * Just reordering some structs. * Adding in the calls to mm_pause * Passing around the state * Renaming and moving a bunch of variables around. * Extracting the logic to it's own function. * Moving some variable definitions into the chunk function. * Moving some variables around * moving src1_cont inside * Moving row_size * adding the current_chunk * Reorg the code. * Formatting to match the orig patch * starting to setup the chunking variables * Starting the buildup of the loop * The yield shouldn't be necessary. * adding the looping structure based on the chunk configuration. * Add in the re-chunking code. * Making it much more likely to rechunk. * disable resizing if numa is enabled. * Updating comments with what we've learned. * Fix formatting * Couple more formatting fixes. * More style fixes. * Fix Warnings * Going with unused because there's conditional logic that needs it. * Update ggml.c * Update ggml.c --------- --- ggml.c | 381 +++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 237 insertions(+), 144 deletions(-) diff --git a/ggml.c b/ggml.c index 67e17a2102b16..684d84235f8d3 100644 --- a/ggml.c +++ b/ggml.c @@ -112,6 +112,8 @@ typedef void * thread_ret_t; #endif +typedef pthread_t ggml_thread_t; + #ifdef GGML_USE_CPU_HBM #include #endif @@ -1539,6 +1541,59 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) #endif +// +// ggml context +// + +struct ggml_context { + size_t mem_size; + void* mem_buffer; + bool mem_buffer_owned; + bool no_alloc; + bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers + + int n_objects; + + struct ggml_object* objects_begin; + struct ggml_object* objects_end; + + struct ggml_scratch scratch; + struct ggml_scratch scratch_save; +}; + +struct ggml_context_container { + bool used; + + struct ggml_context context; +}; + +struct ggml_compute_state_shared { + const struct ggml_cgraph* cgraph; + const struct ggml_cplan* cplan; + + int64_t perf_node_start_cycles; + int64_t perf_node_start_time_us; + + const int n_threads; + + // synchronization primitives + atomic_int n_active; // num active threads + atomic_int node_n; // active graph node + atomic_int node_task; // active graph node task phase + + ggml_abort_callback abort_callback; // abort ggml_graph_compute when true + void* abort_callback_data; + + atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. +}; + +struct ggml_compute_state { + ggml_thread_t thrd; + int ith; + struct ggml_compute_state_shared* shared; + enum ggml_status ec; +}; + // // fundamental operations // @@ -2385,32 +2440,6 @@ static void ggml_setup_op_has_task_pass(void) { } } -// -// ggml context -// - -struct ggml_context { - size_t mem_size; - void * mem_buffer; - bool mem_buffer_owned; - bool no_alloc; - bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers - - int n_objects; - - struct ggml_object * objects_begin; - struct ggml_object * objects_end; - - struct ggml_scratch scratch; - struct ggml_scratch scratch_save; -}; - -struct ggml_context_container { - bool used; - - struct ggml_context context; -}; - // // NUMA support // @@ -11815,9 +11844,101 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) { } #endif +static void ggml_compute_forward_mul_mat_one_chunk( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const int64_t num_rows_per_vec_dot, + const int64_t ir0_start, + const int64_t ir0_end, + const int64_t ir1_start, + const int64_t ir1_end) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + + const bool src1_cont = ggml_is_contiguous(src1); + + ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + + //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end); + + // threads with no work simply yield (not sure if it helps) + if (ir0_start >= ir0_end || ir1_start >= ir1_end) { + return; + } + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; + + // attempt to reduce false-sharing (does not seem to make a difference) + // 16 * 2, accounting for mmla kernels + float tmp[32]; + + for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) { + const int64_t i13 = (ir1 / (ne12 * ne1)); + const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; + const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); + + // broadcast src0 into src1 + const int64_t i03 = i13 / r3; + const int64_t i02 = i12 / r2; + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char*)wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size + : (i11 * nb11 + i12 * nb12 + i13 * nb13)); + float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + } + + for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { + memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); + } + } + } + } +} + static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + struct ggml_tensor * dst, + struct ggml_compute_state * state) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -11832,9 +11953,6 @@ static void ggml_compute_forward_mul_mat( const enum ggml_type type = src0->type; - const bool src1_cont = ggml_is_contiguous(src1); - - ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; int64_t const vec_dot_num_rows = type_traits[type].nrows; @@ -11855,8 +11973,10 @@ static void ggml_compute_forward_mul_mat( GGML_ASSERT(nb2 <= nb3); // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + UNUSED(r2); + UNUSED(r3); // nb01 >= nb00 - src0 is not transposed // compute by src0 rows @@ -11938,6 +12058,8 @@ static void ggml_compute_forward_mul_mat( #endif #if GGML_USE_LLAMAFILE + const bool src1_cont = ggml_is_contiguous(src1); + if (src1_cont) { for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) @@ -11963,6 +12085,8 @@ UseGgmlGemm1:; if (ith != 0) { return; } + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + atomic_store(&state->shared->current_chunk, nth); if (src1->type != vec_dot_type) { char * wdata = params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); @@ -11987,11 +12111,11 @@ UseGgmlGemm1:; return; } - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = ggml_row_size(vec_dot_type, ne10); - #if GGML_USE_LLAMAFILE if (src1->type != vec_dot_type) { + const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), @@ -12012,98 +12136,87 @@ UseGgmlGemm1:; UseGgmlGemm2:; #endif - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = ne1*ne12*ne13; // src1 rows - - //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); - - // distribute the thread work across the inner or outer loop based on which one is larger - - const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - const int64_t ith0 = ith % nth0; - const int64_t ith1 = ith / nth0; - - const int64_t dr0 = (nr0 + nth0 - 1)/nth0; - const int64_t dr1 = (nr1 + nth1 - 1)/nth1; - - const int64_t ir010 = dr0*ith0; - const int64_t ir011 = MIN(ir010 + dr0, nr0); - - const int64_t ir110 = dr1*ith1; - const int64_t ir111 = MIN(ir110 + dr1, nr1); - - //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); - - // threads with no work simply yield (not sure if it helps) - if (ir010 >= ir011 || ir110 >= ir111) { - sched_yield(); - return; - } +#ifdef GGML_PERF + int chunks_executed = 0; + UNUSED(chunks_executed); +#endif - assert(ne12 % ne02 == 0); - assert(ne13 % ne03 == 0); + // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) + const int64_t nr0 = ne0; - // block-tiling attempt - const int64_t blck_0 = 16; - const int64_t blck_1 = 16; + // This is the size of the rest of the dimensions of the result + const int64_t nr1 = ne1 * ne2 * ne3; // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols - int64_t nrc = vec_dot_num_rows; + int64_t num_rows_per_vec_dot = vec_dot_num_rows; // TODO: currently the mmla kernels support only even numbered rows/cols. // this check can be removed once they are extended to support odd numbered rows/cols too if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { - nrc = 1; + num_rows_per_vec_dot = 1; } - const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; + // Now select a reasonable chunk size. + int chunk_size = 16; - // attempt to reduce false-sharing (does not seem to make a difference) - // 16 * 2, accounting for mmla kernels - float tmp[32]; + // We need to step up the size if it's small + if (nr0 == 1 || nr1 == 1) { + chunk_size = 64; + } - for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { - for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) { - const int64_t i13 = (ir1/(ne12*ne1)); - const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1; - const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1); + // distribute the work across the inner or outer loop based on which one is larger + // The number of chunks in the 0/1 dim. + // CEIL(nr0/chunk_size) + int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; + int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; - // broadcast src0 into src1 - const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; + // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. + // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 + // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. + if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) { + // distribute the thread work across the inner or outer loop based on which one is larger + nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + } - const int64_t i1 = i11; - const int64_t i2 = i12; - const int64_t i3 = i13; + // The number of elements in each chunk + const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03); + //if (ith == 0) + // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1); - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size - : (i11*nb11 + i12*nb12 + i13*nb13)); - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} + while (current_chunk < nchunk0 * nchunk1) { + const int64_t ith0 = current_chunk % nchunk0; + const int64_t ith1 = current_chunk / nchunk0; - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) { - vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc); - } + const int64_t ir0_start = dr0 * ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); - for (int cn = 0; cn < nrc; ++cn) { - memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); - } - } + const int64_t ir1_start = dr1 * ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); + + ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); + +#ifdef GGML_PERF + chunks_executed++; +#endif + + if (nth >= nchunk0 * nchunk1) { + break; } + + current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1); } + +#ifdef GGML_PERF + // These numbers are useful when trying to measure how well the threading scheduling works. + //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1; + //float time = (ggml_perf_time_us() - t0); + //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed); +#endif } // ggml_compute_forward_mul_mat_id @@ -17358,7 +17471,7 @@ static void ggml_compute_forward_cross_entropy_loss_back( ///////////////////////////////// -static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { +static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) { GGML_ASSERT(params); if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { @@ -17456,7 +17569,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_MUL_MAT: { - ggml_compute_forward_mul_mat(params, tensor); + ggml_compute_forward_mul_mat(params, tensor, state); } break; case GGML_OP_MUL_MAT_ID: { @@ -19072,8 +19185,6 @@ typedef int ggml_lock_t; #define GGML_LOCK_INITIALIZER 0 -typedef pthread_t ggml_thread_t; - #define ggml_thread_create pthread_create #define ggml_thread_join pthread_join @@ -19099,8 +19210,6 @@ typedef int ggml_lock_t; #define GGML_LOCK_INITIALIZER 0 -typedef pthread_t ggml_thread_t; - #define ggml_thread_create pthread_create #define ggml_thread_join pthread_join @@ -19180,31 +19289,6 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); } static void clear_numa_thread_affinity(void) {} #endif -struct ggml_compute_state_shared { - const struct ggml_cgraph * cgraph; - const struct ggml_cplan * cplan; - - int64_t perf_node_start_cycles; - int64_t perf_node_start_time_us; - - const int n_threads; - - // synchronization primitives - atomic_int n_active; // num active threads - atomic_int node_n; // active graph node - atomic_int node_task; // active graph node task phase - - ggml_abort_callback abort_callback; // abort ggml_graph_compute when true - void * abort_callback_data; -}; - -struct ggml_compute_state { - ggml_thread_t thrd; - int ith; - struct ggml_compute_state_shared * shared; - enum ggml_status ec; -}; - static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) { int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles; int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us; @@ -19477,6 +19561,10 @@ static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_comput * node_n = atomic_load(&state->shared->node_n); if (* node_n != last_node_n) break; +#if defined(__SSE3__) + // Tell the processor we're spinning. It's a processor hint for spinlocks. + _mm_pause(); +#endif } } @@ -19491,6 +19579,10 @@ static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_co * task_phase = atomic_load(&state->shared->node_task); if (* task_phase != last_task_phase) break; +#if defined(__SSE3__) + // Tell the processor we're spinning. It's a processor hint for spinlocks. + _mm_pause(); +#endif } } @@ -19530,7 +19622,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_tensor * node = cgraph->nodes[node_n]; if (GGML_OP_HAS_FINALIZE[node->op]) { params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads); - ggml_compute_forward(¶ms, node); + ggml_compute_forward(¶ms, node, state); } ggml_graph_compute_perf_stats_node(node, state->shared); } @@ -19550,17 +19642,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /* INIT */ if (GGML_OP_HAS_INIT[node->op]) { params.type = GGML_TASK_TYPE_INIT; - ggml_compute_forward(¶ms, node); + ggml_compute_forward(¶ms, node, state); } // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, // they do something more efficient than spinning (?) params.type = GGML_TASK_TYPE_COMPUTE; - ggml_compute_forward(¶ms, node); + ggml_compute_forward(¶ms, node, state); if (GGML_OP_HAS_FINALIZE[node->op]) { params.type = GGML_TASK_TYPE_FINALIZE; - ggml_compute_forward(¶ms, node); + ggml_compute_forward(¶ms, node, state); } ggml_graph_compute_perf_stats_node(node, state->shared); @@ -19599,7 +19691,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { if (state->ith < n_tasks) { if (GGML_OP_HAS_INIT[node->op]) { - ggml_compute_forward(¶ms, node); + ggml_compute_forward(¶ms, node, state); } } @@ -19620,7 +19712,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { if (state->ith < n_tasks) { params.type = GGML_TASK_TYPE_COMPUTE; - ggml_compute_forward(¶ms, node); + ggml_compute_forward(¶ms, node, state); } if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { @@ -19871,6 +19963,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl /*.node_task =*/ GGML_TASK_TYPE_FINALIZE, /*.abort_callback =*/ NULL, /*.abort_callback_data =*/ NULL, + /*.current_chunk; =*/ 0, }; struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);